# Notebook to download images from the Unsplash dataset and save metadata to a database
# Imports
The code relies on several libraries to work correctly. The libraries are as follows:

os: This module provides a portable way of using operating system dependent functionality like reading or writing to the file system. It is used here to join paths, create directories, and read the database credentials from environment variables.

zipfile: This module allows us to read and write ZIP archive files. It is used here to extract the dataset from a compressed archive.

requests: This module allows us to send HTTP/1.1 requests using Python. It is used here to download the dataset archive from a URL.

functools: This module provides various functions that can be used to create higher-order functions. It is used here to wrap the raw response reader with functools.partial.

pathlib: This module provides a higher-level interface to working with file system paths than the os module. It is used here to resolve the download path and create its parent directories.

tqdm: This module provides a progress bar that can be used to track the progress of a long-running operation. It is used here to show the progress of the dataset download.

pandas: This module is used for data manipulation and analysis. It is used here to read the photos TSV file into a pandas DataFrame.

nest_asyncio: This module provides a way to run nested event loops in asyncio, which is needed because the notebook already runs its own event loop.

asyncio: This module provides infrastructure for writing single-threaded concurrent code using coroutines, multiplexing I/O access over sockets and other resources, running network clients and servers, and other related primitives.

aiohttp: This module provides an asynchronous HTTP client/server implementation using asyncio. It is used here to download the images concurrently.

time: This module provides various time-related functions. It is used here to measure the time taken to download images.

subprocess: This module provides a way to spawn new processes, connect to their input/output/error pipes, and obtain their return codes. It is used here to run the exifextract binary that extracts the metadata from the images.

csv: This module reads and writes CSV files. It is used here to read the metadata file produced by exifextract.

mysql.connector: This module provides a Python interface to MySQL databases. It is used here to connect to the database that stores the metadata.

dotenv: This module loads environment variables from a .env file. It is used here to load the database credentials.

In [None]:
!pip install requests tqdm pandas pillow nest_asyncio aiohttp mysql-connector-python python-dotenv

In [None]:
import os
import zipfile
import requests
import functools
import pathlib
from tqdm import tqdm
import pandas as pd
import nest_asyncio
import aiohttp
import time
import asyncio
import mysql.connector
from mysql.connector import Error
import csv
from dotenv import load_dotenv
load_dotenv()

# Setting base variables and paths
This project uses the Unsplash Lite dataset, a large-scale image dataset. The dataset contains over 25,000 images.
The code sets the base variables and paths for the project. The variables are as follows:

In [None]:
# Set the base folder path for the project
output_path = "../output"
images_path = os.path.join(output_path, "images")
metadata_path = os.path.join(output_path, "metadata")
include_path = os.path.join(output_path, "include")

list_of_paths = [output_path, images_path, metadata_path, include_path]

# Set the base URL for the dataset
dataset_url = "https://unsplash.com/data/lite/latest"
# metadata mode (used to save metadata)
metadata_mode = "sqlite"

# Set the number of images to download
num_images = 1000

# Set SQL variables
sql_host = os.getenv("SQL_HOST")
sql_user = os.getenv("SQL_USER")
sql_password = os.getenv("SQL_PASSWORD")
sql_database = os.getenv("SQL_DATABASE")

# Create folder structure
The code creates the folder structure for the project. The folder structure is as follows:
- output
    - images
    - metadata
    - include

This method creates a folder with the given path if it doesn't already exist. It also outputs a message to inform the user whether the folder was created or already exists.
This is useful for organizing and managing files in a project. By creating a folder to store data and resources, it keeps the working directory tidy and makes it easier to locate files. Additionally, by handling the case where the folder already exists, it prevents the program from overwriting existing data or stopping with an error.

In [None]:
def create_folder(path):
    """
    This function creates a folder at the specified path.
    If the folder already exists, it will print a message saying so.
    If there is an error creating the folder, it will print the error message.

    Parameters:
        :param path (str): The path of the folder to be created.

    Returns:
    None
    """
    try:
        # Use os.mkdir to create the folder at the specified path
        os.mkdir(path)
        print(f"Folder {path} created")
    except FileExistsError:
        # If the folder already exists, print a message saying so
        print(f"Folder {path} already exists")
    except Exception as e:
        # If there is an error creating the folder, print the error message
        print(f"Error creating folder {path}: {e}")

# Create the folder structure
This method initializes a list of folders by calling the create_folder method for each folder in the list.
The purpose of this method is to make sure that all necessary folders exist before the program continues its execution.
If a folder does not exist, the create_folder method will create it. If a folder already exists, the method will simply print a message indicating that the folder already exists. In case of any other error, the method will print the error message.
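
As a compact alternative, the same folder initialization can be sketched with pathlib. This is only an illustration: ensure_folders is a hypothetical helper that is not part of the notebook, and it creates missing parent folders as well, which os.mkdir does not.

```python
import pathlib
import tempfile


def ensure_folders(paths):
    """Create each folder (and any missing parents); return the newly created ones."""
    created = []
    for raw_path in paths:
        folder = pathlib.Path(raw_path)
        if not folder.exists():
            created.append(str(folder))
        # exist_ok=True makes the call a no-op when the folder is already there
        folder.mkdir(parents=True, exist_ok=True)
    return created


# Demonstration inside a temporary directory (hypothetical layout)
with tempfile.TemporaryDirectory() as tmp:
    wanted = [str(pathlib.Path(tmp) / "output" / "images"),
              str(pathlib.Path(tmp) / "output" / "metadata")]
    print(ensure_folders(wanted))  # both folders are reported as created
    print(ensure_folders(wanted))  # second call creates nothing
```

Because the call is idempotent, running the cell a second time leaves the existing folders and their contents untouched.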

In [None]:
def init_folder(folder_names: list):
    for folder_name in folder_names:
        create_folder(folder_name)

In [None]:
init_folder(list_of_paths)

# Define methods for downloading the dataset
The following code block is a method to download a file from a given URL and save it to a specified filename.

The method starts by creating a session (s = requests.Session()) and then mounting it to the URL (s.mount(url, requests.adapters.HTTPAdapter(max_retries=3))). This sets the maximum number of retries to 3 if the connection to the URL fails.

Then, the method makes a GET request to the URL (r = s.get(url, stream=True, allow_redirects=True)) and checks if it returns a successful response (r.raise_for_status()). If there was an HTTP error during the request, the error message is printed (print(f"HTTP error occurred while downloading dataset: {e}")).

The method also checks the file size specified in the response headers and assigns it to the variable file_size (file_size = int(r.headers.get('Content-Length', 0))). If the file size is 0, a default value of "(Unknown total file size)" is assigned to the variable desc; otherwise, the variable desc is left empty.

Next, the method resolves the file path and creates the parent directory if it doesn't already exist (path.parent.mkdir(parents=True, exist_ok=True)). The method then creates a tqdm progress bar to show the download progress (with tqdm(total=file_size, unit='B', unit_scale=True, desc=desc) as pbar:).

Finally, the method writes the contents of the file to disk in chunks (for chunk in r.iter_content(chunk_size=1024):), updating the progress bar for each chunk that is written to disk (pbar.update(len(chunk))). If an error occurs during the download, a message with the error is printed (print(f"Error occurred while downloading dataset: {e}")) and the method returns None. Otherwise, the file path is returned when the method is finished.
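
The chunked write can be tried without a network connection. The sketch below is an illustration only: save_stream and on_progress are hypothetical names, and an io.BytesIO object stands in for the HTTP response stream.

```python
import io
import pathlib
import tempfile


def save_stream(stream, destination, chunk_size=1024, on_progress=None):
    """Copy a binary stream to destination in chunks; return the bytes written."""
    path = pathlib.Path(destination).expanduser().resolve()
    # Create parent directories if they don't exist
    path.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    with path.open("wb") as f:
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:  # an empty read means the stream is exhausted
                break
            f.write(chunk)
            written += len(chunk)
            if on_progress is not None:
                on_progress(len(chunk))  # plays the role of pbar.update
    return written


with tempfile.TemporaryDirectory() as tmp:
    target = pathlib.Path(tmp) / "nested" / "archive.bin"
    updates = []
    total = save_stream(io.BytesIO(b"x" * 2500), target, on_progress=updates.append)
    print(total, updates)  # prints: 2500 [1024, 1024, 452]
```

Writing in fixed-size chunks keeps memory usage bounded regardless of the size of the file.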

In [None]:
def download(url, filename):
    """
    Downloads a file from a given URL and saves it to a specified filename.

    Parameters:
        :param url (str): The URL of the file to be downloaded.
        :param filename (str): The filename to save the file as.

    Returns:
    path (pathlib.Path): The resolved path of the downloaded file, or None if the download failed.
    """
    try:
        # Create a session object to persist the state of connection
        s = requests.Session()
        # Retry connecting to the URL up to 3 times
        s.mount(url, requests.adapters.HTTPAdapter(max_retries=3))
        # Send a GET request to the URL to start the download
        r = s.get(url, stream=True, allow_redirects=True)
        # Raise an error if the response is not 200 OK
        r.raise_for_status()
        # Get the file size from the Content-Length header, default to 0 if not present
        file_size = int(r.headers.get('Content-Length', 0))
        # Get the absolute path to the target file
        path = pathlib.Path(filename).expanduser().resolve()
        # Create parent directories if they don't exist
        path.parent.mkdir(parents=True, exist_ok=True)
        # Set the description to display while downloading, "(Unknown total file size)" if file size is 0
        desc = "(Unknown total file size)" if file_size == 0 else ""
        # Enable decoding the response content
        r.raw.read = functools.partial(r.raw.read, decode_content=True)
        # Use tqdm to display the download progress
        with tqdm(total=file_size, unit='B', unit_scale=True, desc=desc) as pbar:
            # Open the target file in binary write mode
            with path.open("wb") as f:
                # Write each chunk of data from the response to the file
                for chunk in r.iter_content(chunk_size=1024):
                    f.write(chunk)
                    pbar.update(len(chunk))
        # Return the path to the downloaded file
        return path
    # Handle HTTP error if the response is not 200 OK
    except requests.exceptions.HTTPError as e:
        print(f"HTTP error occurred while downloading dataset: {e}")
    # Handle any other exceptions that might occur while downloading the file
    except Exception as e:
        print(f"Error occurred while downloading dataset: {e}")

# Download the dataset
The following code block downloads the dataset archive from the URL, extracts its contents into the images folder, and then removes the archive. The method prints a message after each step to inform the user of its progress.
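
When only part of an archive is needed, zipfile can extract selected members instead of everything. The sketch below is an illustration under assumptions: extract_selected is a hypothetical helper, and the demo archive and its member names are created on the fly rather than taken from the real dataset.

```python
import pathlib
import tempfile
import zipfile


def extract_selected(archive_path, destination, wanted_prefix):
    """Extract only members whose name starts with wanted_prefix; return their names."""
    with zipfile.ZipFile(archive_path, "r") as zip_ref:
        selected = [name for name in zip_ref.namelist()
                    if name.startswith(wanted_prefix)]
        zip_ref.extractall(destination, members=selected)
    return selected


with tempfile.TemporaryDirectory() as tmp:
    archive = pathlib.Path(tmp) / "archive.zip"
    # Build a tiny demo archive with illustrative member names
    with zipfile.ZipFile(archive, "w") as demo:
        demo.writestr("photos.tsv000", "photo_id\tphoto_image_url\n")
        demo.writestr("keywords.tsv000", "photo_id\tkeyword\n")
    extracted = extract_selected(archive, pathlib.Path(tmp) / "images", "photos")
    print(extracted)  # prints: ['photos.tsv000']
```

Extracting only the required files avoids having to delete the unused ones afterwards.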

In [None]:
def download_dataset(dataset_url, image_path):
    """
    Downloads the dataset from the given URL, unzips it, and stores the images in the specified image path.

    Args:
        :param dataset_url (str): URL of the dataset to be downloaded
        :param image_path (str): Path to store the images after unzipping the dataset
    """
    # Check if the dataset has already been downloaded
    # Only download when archive.zip is not already present in the working directory
    if not os.path.exists('archive.zip'):
        # Download the dataset from the given url; download() returns None on failure
        if download(dataset_url, 'archive.zip') is None:
            return
        print("Dataset downloaded!")
        try:
            # Extract the contents of the archive.zip to the specified image path
            with zipfile.ZipFile('archive.zip', 'r') as zip_ref:
                zip_ref.extractall(image_path)
            print("Dataset unzipped")
        except Exception as e:
            print(f"Error occurred while unzipping dataset: {e}")
        try:
            # Remove the archive.zip file
            os.remove('archive.zip')
            print("archive.zip removed")
        except Exception as e:
            print(f"Error occurred while removing archive.zip: {e}")

In [None]:
download_dataset(dataset_url, images_path)

Patch asyncio with nest_asyncio so that event loops can be run from inside the notebook, which already runs its own event loop

In [None]:
nest_asyncio.apply()

In [None]:
# Read the photos.tsv000 file in the images folder
photo_df = pd.read_csv(os.path.join(images_path, 'photos.tsv000'), sep='\t')
# Keep only the photo_id and photo_image_url columns
photo_df = photo_df[['photo_id', 'photo_image_url']]

print(photo_df.head())

This method downloads an image from a given URL using an asynchronous HTTP client library called aiohttp. The downloaded image is saved to a file on the local file system with a filename in the format "image_#index.jpg", where #index is a given integer value.

The method takes four arguments:

session: an instance of an aiohttp client session that manages HTTP requests and responses.
url: a string representing the URL of the image to download.
i: an integer representing the index of the image to download.
err_cnt: an optional integer representing the number of times that the download has failed due to a client error. If this is not provided, it defaults to 0.

The method first checks whether an error count was provided, and if not, sets it to 0. It then attempts to download the image using the aiohttp session.get() method, which returns a response object. The async with statement ensures that the response is properly handled and that the connection is released when the request is completed.

Once the response is obtained, the method constructs a filename using the given index value, and writes the image content to a file with that filename in binary mode. It then prints a message indicating that the image was downloaded and saved to the specified filename.

If the download fails due to a client error (e.g., a connection failure), the method catches the error using an except block, prints an error message indicating the URL that failed and the error that occurred, and then waits for 10 seconds before retrying the download. Once the error count reaches 10, the method returns without attempting to download the image again.

HTTP error statuses (e.g., a 404 Not Found response) do not raise an exception, because aiohttp only raises for them when raise_for_status is enabled; the body of such a response is written to the file like any other. Exceptions other than aiohttp.ClientError are not caught and will propagate up the call stack.

If the download is successful, the method returns nothing.
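
The retry logic can also be written as a loop rather than a recursive call. The sketch below is an illustration, not the notebook's implementation: retry_async is a hypothetical helper, ConnectionError stands in for aiohttp.ClientError, and the delay is set to 0 so the demonstration finishes immediately.

```python
import asyncio


async def retry_async(operation, max_attempts=10, delay=10.0, retry_on=(ConnectionError,)):
    """Await operation() until it succeeds or max_attempts is exhausted."""
    for attempt in range(1, max_attempts + 1):
        try:
            return await operation()
        except retry_on as error:
            print(f"Attempt {attempt} failed: {error}")
            if attempt == max_attempts:
                return None  # give up, mirroring the silent return in the notebook
            await asyncio.sleep(delay)


calls = {"count": 0}


async def flaky_download():
    """Fail twice, then succeed (simulated download)."""
    calls["count"] += 1
    if calls["count"] < 3:
        raise ConnectionError("temporary failure")
    return "image bytes"


result = asyncio.run(retry_async(flaky_download, max_attempts=5, delay=0))
print(result, calls["count"])
```

A loop keeps the call stack flat and makes the maximum number of attempts explicit in one place.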

In [None]:
async def download_image(session: aiohttp.ClientSession, url: str, i: int, err_cnt=None):
    """
    Downloads an image from the given URL using an aiohttp client session and saves it to the local file system.

    Args:
        session: An aiohttp client session that manages HTTP requests and responses.
        url: The URL of the image to download.
        i: An integer representing the index of the image to download.
        err_cnt: An optional integer representing the number of times that the download has failed due to a client error.
                 If not provided, it defaults to 0.

    Raises:
        Exceptions other than aiohttp.ClientError (for example, OSError when writing the file) are not caught.

    Returns:
        None.
    """
    if err_cnt is None:
        err_cnt = 0
    try:
        async with session.get(url) as response:
            filename = os.path.join(images_path, "image_" + str(i) + ".jpg")
            with open(filename, 'wb') as f:
                f.write(await response.content.read())
            print(f"Downloaded {url} to {filename} idx: {i}")
    except aiohttp.ClientError as e:
        print(f"Error occurred while downloading {url}: {e}")
        if err_cnt == 10:
            return
        await asyncio.sleep(10)
        err_cnt += 1
        await download_image(session, url, i, err_cnt)

This method, download_images, is a function that downloads a list of images from the web using the aiohttp library and saves them to the local file system. The method takes two arguments: image_urls, a list of URLs where the images are hosted, and images_ids, a list of identifiers for the images that will be used to name the files when they are saved locally.

The method starts by creating an aiohttp.ClientSession object, which is used to manage HTTP requests and responses. It then initializes an empty list tasks and creates a semaphore object with a maximum value of 5000. The semaphore is used to limit the number of concurrent downloads and prevent overloading the server.

Next, the method loops through the image_urls list using the built-in enumerate function to keep track of the index of each URL. For each URL, the method tries to acquire a permit from the semaphore to start a new download. If the maximum number of concurrent downloads has been reached, the method blocks until a permit becomes available.

Once a permit is acquired, the method creates a new task using the asyncio.ensure_future function and adds it to the tasks list. The task represents the asynchronous download of the image from the current URL using the download_image method. A callback function is also added to the task that releases the semaphore permit when the task completes, so that another download can start.

If an error occurs while scheduling a download, the method prints an error message and releases the semaphore permit. The method then continues with the next URL in the list. Client errors that occur during the download itself are handled inside download_image.

After all tasks have been created and added to the tasks list, the method waits for them to complete using the asyncio.wait function. Finally, it gathers all the completed tasks using the asyncio.gather function, which ensures that all tasks have finished before the method returns.

Overall, the download_images method is an efficient and asynchronous way to download a large number of images from the web and save them to the local file system.
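
The concurrency limit can also be expressed by acquiring the semaphore inside each worker with async with, which releases the permit automatically. The sketch below is an illustration under assumptions: run_bounded and fake_download are hypothetical names, and asyncio.sleep simulates the network request.

```python
import asyncio


async def run_bounded(jobs, limit):
    """Run zero-argument coroutine functions with at most `limit` active at once."""
    semaphore = asyncio.Semaphore(limit)
    active = 0
    peak = 0

    async def worker(job):
        nonlocal active, peak
        async with semaphore:  # the permit is released even if the job raises
            active += 1
            peak = max(peak, active)
            try:
                return await job()
            finally:
                active -= 1

    # gather returns the results in the same order as the jobs
    results = await asyncio.gather(*(worker(job) for job in jobs))
    return results, peak


async def fake_download(i):
    await asyncio.sleep(0.01)  # simulated network latency
    return i


async def main():
    jobs = [lambda i=i: fake_download(i) for i in range(10)]
    return await run_bounded(jobs, limit=3)


results, peak = asyncio.run(main())
print(results, peak)  # peak never exceeds the limit of 3
```

With this pattern no done-callback is needed, and a single asyncio.gather call both waits for the tasks and collects their results.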

In [None]:
async def download_images(image_urls, images_ids):
    """
    Downloads a list of images from the given URLs using an aiohttp client session and saves them to the local file system.

    Args:
        image_urls: A list of strings representing the URLs of the images to download.
        images_ids: A list of integers representing the indices of the images to download.

    Raises:
        Any exception raised by a download task is re-raised by asyncio.gather.

    Returns:
        None.
    """
    # Create a new aiohttp client session to manage HTTP requests and responses
    async with aiohttp.ClientSession() as session:
        tasks = []  # Create an empty list to hold the tasks that will download the images
        semaphore = asyncio.Semaphore(5000)  # Create a semaphore to limit the number of concurrent downloads
        # Loop through the image URLs and create a new task for each one
        for i, url in enumerate(image_urls):
            try:
                await semaphore.acquire()  # Acquire a permit from the semaphore to limit concurrency
                #url = url + "?w=1000&fm=jpg&fit=max"  # Append query parameters to resize and optimize the image
                task = asyncio.ensure_future(download_image(session, url, images_ids[i]))  # Create a new download task
                task.add_done_callback(
                    lambda x: semaphore.release())  # Release the semaphore permit when the task completes
                tasks.append(task)  # Add the task to the list of download tasks
            except Exception:
                print(f"Error occurred while downloading {url}")
                semaphore.release()  # Release the semaphore permit if an exception occurs
        # Wait for all download tasks to complete
        await asyncio.wait(tasks)
        # Gather the results of all download tasks (not necessary because the tasks have already completed)
        await asyncio.gather(*tasks)

In [None]:
# Get the list of image URLs, limited to num_images
image_urls = photo_df['photo_image_url'].values.tolist()[:num_images]
# Image ids run from 0 to the size of the list
images_ids = list(range(len(image_urls)))
# Keep only the (url, id) pairs whose image file is not already in the folder.
# Both lists are filtered together so that each id stays aligned with its URL.
pending = [(url, image_id) for url, image_id in zip(image_urls, images_ids) if
           not os.path.exists(os.path.join(images_path, "image_" + str(image_id) + ".jpg"))]
image_urls = [url for url, _ in pending]
images_ids = [image_id for _, image_id in pending]
print(f"Number of images to download: {len(image_urls)}")

In [None]:
# Split the list of image URLs into chunks of at most 5000 and download each chunk in its own event loop
chunks = [image_urls[i:i + 5000] for i in range(0, len(image_urls), 5000)]
start_t = time.time()
loop = None
for i, chunk in enumerate(chunks):
    start = time.time()
    try:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(download_images(chunk, images_ids[i * 5000:(i + 1) * 5000]))
    except Exception as e:
        print(f"Error occurred while downloading chunk {i}: {e}")
    finally:
        loop.close()
        print(f"[Chunk {i}] Downloaded {len(chunk)} images in {time.time() - start} seconds")

print(f'Downloaded {len(image_urls)} images in {time.time() - start_t} seconds')

Remove every file from the images folder except the images and TERMS.md

In [None]:
# Remove all files except images
for file in os.listdir(images_path):
    if file.endswith('.jpg'):
        continue
    else:
        try:
            # Don't delete TERMS.md
            if file == 'TERMS.md':
                continue
            os.remove(os.path.join(images_path, file))
        except Exception as e:
            continue

# Define methods to get all the image paths
The get_all_images method is used to retrieve all images present in the specified image path. It uses the os.walk function to traverse through all subdirectories within the image path and collects the file names that end with either '.png' or '.jpg' extensions. The full path of each image is then generated by joining the root directory and the file name. The method returns a list of all images' full paths. In case of any error, an error message is printed and an empty list is returned.
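
The same traversal can be sketched with pathlib. This is an illustration only: find_images is a hypothetical helper, and unlike the notebook's version it compares extensions case-insensitively and returns the paths sorted.

```python
import pathlib
import tempfile


def find_images(root, suffixes=(".png", ".jpg")):
    """Return the sorted paths of all files under root with one of the suffixes."""
    return sorted(
        str(path)
        for path in pathlib.Path(root).rglob("*")
        if path.is_file() and path.suffix.lower() in suffixes
    )


with tempfile.TemporaryDirectory() as tmp:
    base = pathlib.Path(tmp)
    (base / "sub").mkdir()
    (base / "image_0.jpg").write_bytes(b"")
    (base / "sub" / "image_1.PNG").write_bytes(b"")
    (base / "notes.txt").write_text("not an image")
    found = [pathlib.Path(p).name for p in find_images(base)]
    print(found)  # prints: ['image_0.jpg', 'image_1.PNG']
```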

In [None]:
def get_all_images(path):
    """Get all images from the given path.

    Args:
    param: path (str): path to the directory containing the images.

    Returns:
    - list: a list of full path to all the images with png or jpg extensions.
    - empty list: an empty list if an error occurred while fetching images.
    """
    try:
        # use os.walk to traverse all the subdirectories and get all images
        return [os.path.join(root, name)
                for root, dirs, files in os.walk(path)
                for name in files
                if name.endswith((".png", ".jpg"))]
    except Exception as e:
        # return an empty list and log the error message if an error occurred
        print(f"An error occurred while fetching images: {e}")
        return []

# Define methods to get metadata
The goal of the get_metadata method is to extract metadata information from a list of image files and return it in a dictionary format. This method uses the ExifTool software to extract the metadata information from the images. The input parameter is a string containing all image file paths separated by a space. The output of the method is a dictionary containing the metadata information of the images. If an error occurs for any image, the metadata for that image will be None. This method is implemented as an asynchronous coroutine using the asyncio module in Python.


The gen_sql_requests function takes the metadata dictionaries of the images and generates a list of SQL requests to insert the metadata into a database.

The function first creates an empty list to store the SQL requests. It then loops over each metadata dictionary in the input. For each metadata dictionary, the function extracts the filename of the image and then loops over all the items in the dictionary.

For each key-value pair in the metadata dictionary, the function replaces single and double quotes in the value with spaces and creates an SQL request to insert the metadata into the database. The SQL request is in the form of a string that contains the filename, key, and value of the metadata item. The function adds each SQL request to the list of SQL requests.

If an error occurs for an image, the function prints the error and continues with the next image. After processing all the metadata dictionaries, the function returns the list of SQL requests.
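
Building SQL strings by hand requires stripping quotes from the values. A common alternative is to let the database driver bind the values through placeholders. The sketch below is an illustration under assumptions: it uses an in-memory SQLite database as a stand-in, insert_metadata is a hypothetical helper, and the column names meta_key and meta_value are made up. With mysql.connector the placeholder is %s instead of ?.

```python
import sqlite3


def insert_metadata(connection, metadatas):
    """Insert (filename, key, value) rows with placeholders; return the row count."""
    rows = [
        (metadata["filename"], key, value)
        for metadata in metadatas
        for key, value in metadata.items()
        if key != "filename"  # the filename is stored in its own column
    ]
    # The driver binds each value, so quotes in the data need no manual escaping
    connection.executemany("INSERT INTO metadata VALUES (?, ?, ?)", rows)
    connection.commit()
    return len(rows)


# In-memory SQLite database used as a stand-in for the real database
connection = sqlite3.connect(":memory:")
connection.execute(
    "CREATE TABLE metadata (filename TEXT, meta_key TEXT, meta_value TEXT)")
sample = [{"filename": "image_0.jpg", "Make": "Canon", "Artist": "O'Brien"}]
inserted = insert_metadata(connection, sample)
print(inserted, connection.execute("SELECT * FROM metadata").fetchall())
```

The value O'Brien is stored unchanged, which the string-based approach cannot do without altering the data.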

In [None]:
def gen_sql_requests(metadatas):
    """
    This function generates a list of SQL requests to insert metadata into a database.

    Parameters:
    metadatas (list): A list of dictionaries containing the metadata information of the images.

    Returns:
    list: A list of SQL requests to insert metadata into a database.
    """
    # Create a list to store SQL requests
    sql_requests = []
    # Loop over all metadata
    for i, metadata in enumerate(metadatas):
        try:
            # Get the filename of the image
            filename = metadatas[i]['filename']

            # Loop over all metadata items
            for key, value in metadatas[i].items():
                # Create SQL request to insert metadata into database
                # replace " by space
                value = value.replace('"', ' ')
                # replace ' by space
                value = value.replace("'", ' ')

                sql_request = f"INSERT INTO metadata VALUES ('{filename}', '{key}', '{value}')"
                # Add SQL request to list
                sql_requests.append(sql_request)

        except Exception as e:
            # Print an error message if an error occurs
            print("An error occurred while generating SQL requests: ", e)
            continue
    # Return the list of SQL requests
    return sql_requests

The create_server_connection method opens a connection to a MySQL database. It takes four parameters: host_name, user_name, user_password, and db_name, which the notebook reads from the SQL_HOST, SQL_USER, SQL_PASSWORD, and SQL_DATABASE environment variables.

The method calls the connect() method of the mysql.connector module with these credentials. If the connection succeeds, it prints a confirmation message; if it fails, it prints the error. In both cases it returns the connection object, which is None when the connection could not be established.

The returned connection is used to insert metadata information into the database and to query it afterwards. It assumes that the database and its metadata table already exist.

In [None]:
def create_server_connection(host_name, user_name, user_password, db_name):
    connection = None
    try:
        connection = mysql.connector.connect(
            host=host_name,
            user=user_name,
            passwd=user_password,
            database=db_name
        )
        print("MySQL Database connection successful")
    except Error as err:
        print(f"Error: '{err}'")

    return connection

The get_all_metadata coroutine extracts metadata from all images in a directory and saves it to the database. It takes one parameter, images_path, which is the path to the directory where the images are stored. The metadata_path and include_path variables defined earlier determine where the intermediate files are written and where the extraction binary is located.

Firstly, the function runs the exifextract binary from the include folder, passing it the images directory and the path of the metadata.csv file to write. If the process exits with a non-zero return code, a subprocess.CalledProcessError is raised.

The CSV file is then read with the csv module. The first row is the header, and the first column of each following row is the filename of the image. Each row is converted into a dictionary that maps the header names to the values of that row, plus a filename entry.

The dictionaries are passed to gen_sql_requests, and the resulting SQL requests are written to a queries.sql file in the metadata folder.

Finally, a connection is opened with create_server_connection, the metadata is inserted into the metadata table, the changes are committed, and the connection is closed.
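
Turning CSV rows into one dictionary per image can also be sketched with csv.DictReader. This is an illustration under assumptions: rows_to_metadata is a hypothetical helper, the sample header names are made up, and the first column is assumed to hold the filename.

```python
import csv
import io


def rows_to_metadata(csv_text):
    """Convert CSV text into a list of metadata dictionaries, one per image."""
    reader = csv.DictReader(io.StringIO(csv_text))
    filename_column = reader.fieldnames[0]  # assumption: first column is the filename
    metadata = []
    for row in reader:
        entry = {key: value for key, value in row.items() if key != filename_column}
        # Remove single quotes around the filename
        entry["filename"] = row[filename_column].replace("'", "")
        metadata.append(entry)
    return metadata


sample_csv = "SourceFile,Make,Model\n'image_0.jpg',Canon,EOS 5D\n"
parsed = rows_to_metadata(sample_csv)
print(parsed)
# prints: [{'Make': 'Canon', 'Model': 'EOS 5D', 'filename': 'image_0.jpg'}]
```

DictReader pairs each value with its header name, so no index arithmetic over the header row is needed.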

In [None]:
async def get_all_metadata(images_path):
    """
    This coroutine extracts metadata from all images in a directory and saves the metadata information to the database.

    Parameters:
    images_path (str): The path to the directory where the images are stored.

    Returns:
    None
    """
    # Use the binary exifextract from include path
    binary = include_path + '/exifextract'
    command = [binary, images_path, metadata_path + '/metadata.csv']
    import subprocess
    # Execute the command and wait for the process to terminate.
    # check=True raises subprocess.CalledProcessError on a non-zero return code.
    subprocess.run(command, stdout=subprocess.PIPE, check=True)

    # load metadata from csv
    with open(metadata_path + '/metadata.csv', 'r') as f:
        reader = csv.reader(f)
        metadata = list(reader)
        header = metadata[0]

    metadata = metadata[1:]
    metadata_dict = {}
    for i, row in enumerate(metadata):
        metadata_dict[i] = {}
        for j in range(1, len(header)):
            metadata_dict[i][header[j]] = row[j]
        # add filename to metadata
        # remove ' from row[0]
        row[0] = row[0].replace("'", '')
        metadata_dict[i]['filename'] = row[0]


    # save metadata to database
    print("Generating SQL requests...")
    queries = gen_sql_requests(metadata_dict)
    # save to file queries.sql
    with open(metadata_path + '/queries.sql', 'w') as f:
        f.write(";\n".join(queries))


    print("Saving metadata to database...")
    conn = create_server_connection(sql_host, sql_user, sql_password, sql_database)
    cursor = conn.cursor()
    #cursor.execute("DROP TABLE IF EXISTS metadata")
    #cursor.execute("CREATE TABLE metadata (filename VARCHAR(255), key VARCHAR(255), value VARCHAR(255))")
    # queries.sql holds INSERT statements, not delimited rows, so it cannot be
    # loaded with LOAD DATA INFILE; execute each generated statement instead
    for query in queries:
        cursor.execute(query)
    conn.commit()
    conn.close()

In [None]:
asyncio.run(get_all_metadata(images_path))

# How to look at the metadata
The following cell shows how to look at the metadata information of an image by querying the metadata table of the database.

In [None]:
conn = create_server_connection(sql_host, sql_user, sql_password, sql_database)

cursor = conn.cursor()
cursor.execute("SELECT * FROM metadata WHERE filename = 'image_0.jpg'")

result = cursor.fetchall()
for x in result:
    print(x)

# Notebook 2

# Imports
This code imports several libraries in order to perform some data processing tasks. The libraries used are:

import os: imports the os module that provides a way to interact with the operating system.

import cv2: imports the cv2 module, the OpenCV library for image processing and computer vision tasks.

from sklearn.cluster import MiniBatchKMeans: imports the MiniBatchKMeans class from the sklearn.cluster module, which provides a way to perform K-means clustering on a large dataset by processing it in mini-batches.

from tqdm import tqdm: imports the tqdm function from the tqdm module, which provides a progress bar for long-running operations.

from transformers import DetrImageProcessor, DetrForObjectDetection: imports the DetrImageProcessor and DetrForObjectDetection classes from the transformers module, which provide a way to perform object detection using the DETR (DEtection TRansformer) model.

import torch: imports the torch module, the PyTorch library for machine learning and deep learning tasks.

from PIL import Image: imports the Image class from the PIL module, which provides a way to manipulate and analyze image data.

import mysql.connector and from mysql.connector import Error: import the MySQL driver and its error class, which provide a way to work with MySQL databases.

from dotenv import load_dotenv: imports the load_dotenv function, which loads environment variables such as the database credentials from a .env file.

In [None]:
!pip install opencv-python scikit-learn tqdm transformers torch Pillow python-dotenv mysql-connector-python numpy

In [None]:
import os
import cv2
from sklearn.cluster import MiniBatchKMeans
from tqdm import tqdm
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import mysql.connector
from mysql.connector import Error
from dotenv import load_dotenv
load_dotenv()

# Setting base variables and paths
This code sets the folder paths used by the project, the metadata storage format, and the SQL connection settings, which are read from environment variables.

In [None]:
# Set the base folder path for the project
output_path = "../output"
images_path = os.path.join(output_path, "images")
metadata_path = os.path.join(output_path, "metadata")
config_path = os.path.join(output_path, "config")

list_of_paths = [output_path, images_path, metadata_path, config_path]

# Set the storage format used for the metadata
metadata_extension = "sqlite"

# Set SQL variables
sql_host = os.getenv("SQL_HOST")
sql_user = os.getenv("SQL_USER")
sql_password = os.getenv("SQL_PASSWORD")
sql_database = os.getenv("SQL_DATABASE")

# Create folder structure
The code creates the folder structure for the project. The folder structure is as follows:
- output
    - images
    - metadata
    - config

This method creates a folder at the given path if it doesn't already exist. It also prints a message telling the user whether the folder was created or already existed.
Keeping data and resources in dedicated folders keeps the working directory tidy and makes files easier to locate. Handling the case where the folder already exists also keeps the program from stopping with an error when the notebook is run again.

In [None]:
def create_folder(path):
    """
    This function creates a folder at the specified path.
    If the folder already exists, it will print a message saying so.
    If there is an error creating the folder, it will print the error message.

    Parameters:
        :param path (str): The path of the folder to be created.

    Returns:
    None
    """
    try:
        # Use os.mkdir to create the folder at the specified path
        os.mkdir(path)
        print(f"Folder {path} created")
    except FileExistsError:
        # If the folder already exists, print a message saying so
        print(f"Folder {path} already exists")
    except Exception as e:
        # If there is an error creating the folder, print the error message
        print(f"Error creating folder {path}: {e}")

# Create the folder structure
This method initializes a list of folders by calling the create_folder method for each folder in the list.
The purpose of this method is to make sure that all necessary folders exist before the program continues its execution.
If a folder does not exist, the create_folder method will create it. If a folder already exists, the method will simply print a message indicating that the folder already exists. In case of any other error, the method will print the error message.

In [None]:
def init_folder(folder_names: list):
    for folder_name in folder_names:
        create_folder(folder_name)

In [None]:
init_folder(list_of_paths)
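
As an alternative to catching FileExistsError, the standard library can create a whole folder tree in one call. The sketch below is only an illustration of that option, not the notebook's method: the name ensure_folders is hypothetical, and a temporary directory stands in for ../output.

```python
import os
import tempfile


def ensure_folders(paths):
    """Create every folder in paths, including missing parents (hypothetical helper)."""
    for path in paths:
        # exist_ok=True makes the call a no-op when the folder already exists
        os.makedirs(path, exist_ok=True)


# Demonstration in a temporary directory instead of ../output
base = tempfile.mkdtemp()
demo_paths = [os.path.join(base, "output", name) for name in ("images", "metadata", "config")]
ensure_folders(demo_paths)
ensure_folders(demo_paths)  # running it a second time does not raise
created = sorted(os.listdir(os.path.join(base, "output")))
print(created)
```

Unlike os.mkdir, os.makedirs also creates the parent folder, so the order of the paths in the list no longer matters.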

# Define methods to get all the image paths
The get_all_images method is used to retrieve all images present in the specified image path. It uses the os.walk function to traverse through all subdirectories within the image path and collects the file names that end with either '.png' or '.jpg' extensions. The full path of each image is then generated by joining the root directory and the file name. The method returns a list of all images' full paths. In case of any error, an error message is printed and an empty list is returned.

In [None]:
def get_all_images(path):
    """Get all images from the given path.

    Args:
        path (str): path to the directory containing the images.

    Returns:
        list: the full paths of all images with a .png or .jpg extension,
        or an empty list if an error occurred while fetching images.
    """
    try:
        # use os.walk to traverse all the subdirectories and get all images
        return [os.path.join(root, name)
                for root, dirs, files in os.walk(path)
                for name in files
                if name.endswith((".png", ".jpg"))]
    except Exception as e:
        # return an empty list and log the error message if an error occurred
        print(f"An error occurred while fetching images: {e}")
        return []
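
Because name.endswith((".png", ".jpg")) is case-sensitive, files such as photo.JPG are skipped. The sketch below shows one possible variant based on pathlib that compares lowercased suffixes. The function name list_images and the demo files are assumptions made for illustration.

```python
import tempfile
from pathlib import Path


def list_images(path, extensions=(".png", ".jpg")):
    """Return the sorted full paths of the images under path (hypothetical variant)."""
    return sorted(
        str(p) for p in Path(path).rglob("*")
        if p.is_file() and p.suffix.lower() in extensions
    )


# Build a small demo tree: two images at the top level, one in a subfolder
demo_dir = Path(tempfile.mkdtemp())
(demo_dir / "sub").mkdir()
for name in ("a.jpg", "b.PNG", "notes.txt", "sub/c.png"):
    (demo_dir / name).touch()

found = [Path(p).name for p in list_images(demo_dir)]
print(found)
```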

# Facebook DETR model (detr-resnet-101)

The detect_with_transformers function takes an image that has already been opened with the Python Imaging Library (PIL), then uses a pre-trained model called DEtection TRansformer (DETR) to detect objects within it.

The function instantiates two components of the DETR model: a DetrImageProcessor and a DetrForObjectDetection model. The DetrImageProcessor converts the input image into tensors that can be fed into the DetrForObjectDetection model. The model then predicts bounding boxes and class logits for candidate objects.

Once the model has made its predictions, the function uses the processor.post_process_object_detection method to convert the raw outputs into scores, class labels, and bounding boxes expressed in the coordinates of the input image. The class labels are those of the Common Objects in Context (COCO) dataset, on which the model was trained.

Only the detections with a confidence score above a threshold (0.9 in this case) are kept, and each label id is mapped to its class name through model.config.id2label. A commented-out print statement can be re-enabled to display the class label, confidence score, and location of each detected object. The function returns a list of the detected object class labels.

In [None]:
def detect_with_transformers(image):
    """
    This function detects objects in an image using the DETR (DEtection TRansformer) model by Facebook.

    Args:
    image: A PIL image (already opened with Image.open) to be processed.

    Returns:
    A list containing the labels of the detected objects in the image.

    Raises:
    None.
    """
    # DETR expects a 3-channel image, so convert grayscale or RGBA inputs
    image = image.convert("RGB")
    # Note: the processor and the model are loaded on every call
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101")
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101")
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: no gradients are needed
    with torch.no_grad():
        outputs = model(**inputs)

    # convert outputs (bounding boxes and class logits) to scores, labels and boxes in image coordinates
    # let's only keep detections with score > 0.9
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
    labels = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        labels.append(model.config.id2label[label.item()])
        #print(
        #    f"Detected {model.config.id2label[label.item()]} with confidence "
        #    f"{round(score.item(), 3)} at location {box}"
        #)
    return labels
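
The filtering step can be illustrated without downloading the model. The sketch below uses made-up scores and label ids in place of real model output, and the helper name filter_detections is hypothetical. In the notebook, the thresholding itself is done by post_process_object_detection.

```python
def filter_detections(scores, label_ids, id2label, threshold=0.9):
    """Keep the class names of detections whose score exceeds threshold (hypothetical helper)."""
    return [id2label[label_id]
            for score, label_id in zip(scores, label_ids)
            if score > threshold]


# Made-up scores and label ids, standing in for real model output
id2label = {1: "person", 17: "cat", 18: "dog"}
kept = filter_detections([0.98, 0.42, 0.95, 0.91], [17, 1, 18, 17], id2label)

# The same class can be detected several times; a set removes the duplicates
unique_tags = sorted(set(kept))
print(kept, unique_tags)
```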

# Save  metadata

The save_metadata function saves the metadata of an image in the metadata table of the database. It takes two parameters: metadata and img_name.

metadata is a dictionary that contains the metadata information of an image. img_name is a string that represents the file name (or path) of the image; only the base name is kept, and its extension is normalized to .jpg.

For each key-value pair, the function checks whether a row already exists for that file name and key. If it does, the value is updated; otherwise, a new row is inserted. The connection is opened with the create_server_connection helper, which prints the error and returns None if the connection fails.

If an error occurs while saving the metadata, the function will print an error message indicating the image name and the error that occurred.

The function does not return any value.

In [None]:
def create_server_connection(host_name, user_name, user_password, db_name):
    connection = None
    try:
        connection = mysql.connector.connect(
            host=host_name,
            user=user_name,
            passwd=user_password,
            database=db_name
        )
    except Error as err:
        print(f"Error: '{err}'")

    return connection

In [None]:
def save_metadata(metadata, img_name):
    """
    This function saves the metadata information of an image in the metadata table of the database.
    Parameters:
    metadata (dict): The metadata information of an image.
    img_name (str): The file name of the image.

    Returns:
    None
    """
    try:
        # Get only the file name of the image
        img_name = os.path.basename(img_name)
        # Keep the part before the first dot and normalize the extension to .jpg
        img_name = img_name.split('.')[0] + '.jpg'
        # Open a connection to the database
        conn = create_server_connection(sql_host, sql_user, sql_password, sql_database)
        # Create a cursor
        c = conn.cursor()
        # Insert the metadata into the table
        for key, value in metadata.items():
            # Convert key, value to string
            key = str(key)
            value = str(value)
            # Check if the key is already in the table
            # (mysql.connector uses %s placeholders; the columns are filename, mkey, mvalue)
            c.execute("SELECT * FROM metadata WHERE filename=%s AND mkey=%s", (img_name, key))
            # fetchall() consumes the whole result set before the next statement is executed
            if c.fetchall():
                # If the key is already in the table, update the value
                c.execute("UPDATE metadata SET mvalue=%s WHERE filename=%s AND mkey=%s", (value, img_name, key))
            else:
                # If the key is not in the table, insert the key, value pair
                c.execute("INSERT INTO metadata (filename, mkey, mvalue) VALUES (%s, %s, %s)", (img_name, key, value))
            # Commit the changes
            conn.commit()
        # Close the connection
        conn.close()

    except Exception as e:
        # print an error message if an error occurs
        print(f"An error occurred while saving metadata for {img_name}: {str(e)}")
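
The update-or-insert logic can be tried without a database server. The sketch below reproduces it against an in-memory SQLite database, which is an assumption made purely so the example runs anywhere. Note that sqlite3 uses ? placeholders whereas mysql.connector uses %s, and that the name upsert_metadata is hypothetical.

```python
import sqlite3


def upsert_metadata(conn, filename, metadata):
    """Update existing (filename, mkey) rows and insert the others (illustrative sketch)."""
    cur = conn.cursor()
    for key, value in metadata.items():
        cur.execute("SELECT 1 FROM metadata WHERE filename=? AND mkey=?", (filename, str(key)))
        if cur.fetchone():
            cur.execute("UPDATE metadata SET mvalue=? WHERE filename=? AND mkey=?",
                        (str(value), filename, str(key)))
        else:
            cur.execute("INSERT INTO metadata (filename, mkey, mvalue) VALUES (?, ?, ?)",
                        (filename, str(key), str(value)))
    conn.commit()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metadata (filename TEXT, mkey TEXT, mvalue TEXT)")
upsert_metadata(conn, "image_0.jpg", {"Make": "Canon", "tags": ["cat"]})
upsert_metadata(conn, "image_0.jpg", {"Make": "Nikon"})  # updates the row, does not duplicate it
rows = conn.execute("SELECT mkey, mvalue FROM metadata ORDER BY mkey").fetchall()
print(rows)
```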

# Read metadata from the database

The read_mariadb function reads the metadata of one image from the database. It takes one parameter, filename, the name of the file for which the metadata is to be retrieved.

The function starts by opening a connection with create_server_connection. A cursor is then created to allow interaction with the database. If the metadata table doesn't exist in the database, it is created with three text columns: filename, mkey, and mvalue.

The metadata for the specified filename is retrieved from the database by executing a SQL query that selects all rows where the filename column is equal to the filename parameter. The retrieved metadata is stored in a dictionary, where the keys are taken from the mkey column and the values from the mvalue column.

Finally, the database connection is closed and the metadata dictionary is returned as the result of the function.

In [None]:
def read_mariadb(filename):
    # Open a connection to the database
    conn = create_server_connection(sql_host, sql_user, sql_password, sql_database)
    # Create a cursor
    c = conn.cursor()
    # Create a table if it doesn't exist : filename, mkey, mvalue
    c.execute('''CREATE TABLE IF NOT EXISTS metadata (filename text, mkey text, mvalue text)''')
    # Select all the rows stored for this file (mysql.connector uses %s placeholders)
    c.execute("SELECT * FROM metadata WHERE filename=%s", (filename,))
    # Build a dictionary {mkey: mvalue} from the rows
    metadata = {}
    for row in c.fetchall():
        metadata[row[1]] = row[2]
    # Close the connection
    conn.close()
    return metadata
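
To see the shape of the returned dictionary, the same query can be run against an in-memory SQLite database. This is a sketch under that assumption: the sample rows are made up and the name read_metadata is hypothetical.

```python
import sqlite3


def read_metadata(conn, filename):
    """Return {mkey: mvalue} for one file (illustrative sketch)."""
    cur = conn.execute("SELECT mkey, mvalue FROM metadata WHERE filename=?", (filename,))
    # Each row is a (mkey, mvalue) pair, so dict() builds the mapping directly
    return dict(cur.fetchall())


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metadata (filename TEXT, mkey TEXT, mvalue TEXT)")
conn.executemany("INSERT INTO metadata VALUES (?, ?, ?)", [
    ("image_0.jpg", "Make", "Canon"),
    ("image_0.jpg", "Orientation", "1"),
    ("image_1.jpg", "Make", "Sony"),
])
result = read_metadata(conn, "image_0.jpg")
missing = read_metadata(conn, "unknown.jpg")
print(result, missing)
```

A file with no rows yields an empty dictionary rather than an error.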

# Set tags in metadata
The update_tags function runs the DETR detector on a set of images and stores the detected labels (tags) in the metadata of each image.

The function takes one parameter:

images: a list of file paths for the images that need to be processed.

The function uses the tqdm library to display a progress bar for the image processing. Each image is opened with PIL, resized to 416x416, and passed to the detect_with_transformers function to retrieve the labels (tags).

Duplicate labels are removed, and the remaining labels are stored under the "tags" key of a dictionary indexed by file name (with a .jpg extension).

If an image file is not found, the function prints a message and continues with the next image. Any other error also causes the image to be skipped.

Once all images have been processed, the collected tags are written to the metadata table of the database, one row per file with the key "tags".

In [None]:
def gen_rows(metadata):
    """Flatten the metadata into (filename, mkey, mvalue) rows for a parameterized insert"""
    return [(filename, str(key), str(value))
            for filename, data in metadata.items()
            for key, value in data.items()]


def update_tags(images):
    # Run the DETR detector on each image and collect the detected tags
    metadata = {}
    for image in tqdm(images, desc="Updating tags"):
        # Keep the file name without its extension (also works with several dots in the name)
        file_name = os.path.splitext(os.path.basename(image))[0]
        try:
            image = Image.open(image)
            # resize image to 416x416
            image = image.resize((416, 416))
            labels = detect_with_transformers(image)
            image.close()

            # Remove duplicates from labels
            labels = list(set(labels))
            # add labels to metadata
            metadata[file_name + '.jpg'] = {"tags": labels}
        except FileNotFoundError:
            print("File not found: ", file_name)
            continue
        except Exception as e:
            # Report the error instead of silently skipping the image
            print(f"Error while tagging {file_name}: {e}")
            continue

    rows = gen_rows(metadata)
    if not rows:
        return

    # Save the tags with a parameterized insert, so that quotes inside the
    # values cannot break the statement (LOAD DATA expects delimited data,
    # not INSERT statements)
    conn = create_server_connection(sql_host, sql_user, sql_password, sql_database)
    cursor = conn.cursor()
    cursor.executemany("INSERT INTO metadata (filename, mkey, mvalue) VALUES (%s, %s, %s)", rows)
    conn.commit()
    conn.close()
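
Tag lists such as ['cat', 'dog'] contain single quotes, which break an INSERT statement built by string interpolation. The sketch below shows a parameterized bulk insert, run against an in-memory SQLite database so that it is self-contained. The sample tags are made up. With mysql.connector the placeholders would be %s instead of ?.

```python
import sqlite3

# Made-up tags, in the same shape as the dictionary built by update_tags
tags_by_file = {
    "image_0.jpg": {"tags": ["cat", "dog"]},
    "image_1.jpg": {"tags": ["person"]},
}

# One (filename, mkey, mvalue) row per key-value pair
rows = [(filename, key, str(value))
        for filename, data in tags_by_file.items()
        for key, value in data.items()]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metadata (filename TEXT, mkey TEXT, mvalue TEXT)")
# The driver handles the quoting of every value
conn.executemany("INSERT INTO metadata (filename, mkey, mvalue) VALUES (?, ?, ?)", rows)
conn.commit()

stored = conn.execute("SELECT mvalue FROM metadata WHERE filename=?", ("image_0.jpg",)).fetchone()[0]
row_count = conn.execute("SELECT COUNT(*) FROM metadata").fetchone()[0]
print(stored, row_count)
```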


In [None]:
# Get the list of images
images = os.listdir(images_path)
images = [os.path.join(images_path, image) for image in images]

# update_tags only takes the list of image paths
update_tags(images)

### Now, find dominant colors in the images
The functions rgb_to_hex and find_dominant_colors are used to find the dominant colors in an image.

The function rgb_to_hex takes in an RGB array with 3 values, and returns the hexadecimal representation of the color. This can be useful for formatting colors in a standardized way, as hexadecimal codes are widely used in web development and other applications.

The function find_dominant_colors takes in the path of an image and the optional parameters k, downsample, and resize. The k parameter specifies the number of dominant colors to return, with a default value of 4. The downsample parameter divides the width and height of the image by the given factor (2 by default), and the resize parameter, if not None, resizes the image to a fixed size ((200, 200) by default). Both reduce the number of pixels to cluster and therefore speed up the processing.

The image is first loaded with OpenCV and converted from BGR to RGB, and then reshaped into a list of pixels. The MiniBatchKMeans algorithm is used to cluster the pixels into k clusters. The center of each cluster is converted to its hexadecimal representation and returned in a list, paired with the fraction of pixels assigned to that cluster.

In [None]:
def rgb_to_hex(rgb):
    return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2]))
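
A quick check of the conversion, shown here as a small sketch. rgb_to_hex is repeated so that the example runs on its own, while hex_to_rgb is a hypothetical inverse helper that is not part of the notebook.

```python
def rgb_to_hex(rgb):
    return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2]))


def hex_to_rgb(hex_color):
    """Inverse conversion (hypothetical helper, not part of the notebook)."""
    hex_color = hex_color.lstrip('#')
    return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))


pink = rgb_to_hex((255, 0, 128))
# Cluster centers are floats; int() truncates them before formatting
from_floats = rgb_to_hex((12.7, 200.2, 3.9))
round_trip = hex_to_rgb(pink)
print(pink, from_floats, round_trip)
```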

In [None]:
def find_dominant_colors(image_path, k=4, downsample=2, resize=(200, 200)):
    # Load image and convert to RGB
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Downsample the image
    image = cv2.resize(image, (image.shape[1] // downsample, image.shape[0] // downsample))

    # Resize the image if requested
    if resize is not None:
        image = cv2.resize(image, resize)

    # Flatten the image
    image_flat = image.reshape((image.shape[0] * image.shape[1], 3))

    # Cluster the pixels using KMeans and find percentage of image covered by each color
    clt = MiniBatchKMeans(n_clusters=k, n_init=10, batch_size=100, random_state=42)
    labels = clt.fit_predict(image_flat)

    # Count the number of pixels assigned to each cluster (minlength keeps empty clusters)
    counts = np.bincount(labels, minlength=k)

    # Calculate the fraction of pixels assigned to each cluster, as plain Python floats
    percentages = (counts / len(labels)).tolist()

    # Get the dominant colors
    dominant_colors = clt.cluster_centers_

    # Convert to hexadecimal format
    dominant_colors_hex = [rgb_to_hex(color) for color in dominant_colors]

    # Combine the dominant colors and their fractions into a list of tuples
    result = list(zip(dominant_colors_hex, percentages))

    return result
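
The share computation at the end of find_dominant_colors can be illustrated on its own. The sketch below uses only the standard library and made-up cluster labels; the helper name cluster_shares and the hex colors are assumptions. It also ranks the colors from most to least frequent, which the notebook function does not do.

```python
from collections import Counter


def cluster_shares(labels, k):
    """Fraction of pixels assigned to each of the k clusters (illustrative sketch)."""
    counts = Counter(labels)
    total = len(labels)
    # Clusters that received no pixel get a share of 0.0
    return [counts.get(cluster, 0) / total for cluster in range(k)]


# Made-up cluster labels for 8 pixels and 4 clusters (cluster 3 is empty)
labels = [0, 0, 0, 0, 1, 1, 2, 0]
colors_hex = ["#102030", "#a0b0c0", "#ffffff", "#000000"]
shares = cluster_shares(labels, k=4)

# Pair each color with its share and sort by decreasing share
ranked = sorted(zip(colors_hex, shares), key=lambda pair: pair[1], reverse=True)
print(ranked)
```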

The gen_sql_requests function takes a list of image file names and a list of their dominant colors, and generates a list of SQL requests to insert this information into a database.

The function first creates an empty list to store the SQL requests. It then loops over the file names and the colors in parallel, using the tqdm function to provide a progress bar.

For each pair, the function creates an SQL INSERT request in the form of a string that contains the file name, the key 'dominant_color', and the color value. The function adds each SQL request to the list of SQL requests. Because the values are interpolated directly into the string, they must not contain single quotes.

After processing all the pairs, the function returns the list of SQL requests.

In [None]:
def gen_sql_requests(filenames, colors):
    """
    This function generates a list of SQL requests to insert metadata into a database.

    Returns:
    list: A list of SQL requests to insert metadata into a database.
    """
    # Create a list to store SQL requests
    sql_requests = []

    # Loop over all metadata
    for filename, color in tqdm(zip(filenames, colors), desc="Generating SQL requests"):
        # Create SQL request to insert metadata into database (filename, key, value)
        # format color to avoid errors with quote marks

        sql_request = f"INSERT INTO metadata VALUES ('{filename}', 'dominant_color', '{color}')"
        # Add SQL request to list
        sql_requests.append(sql_request)

    # Return the list of SQL requests
    return sql_requests

This method executes a SQL query on a SQLite database. The method takes a single parameter, query, which should be a valid SQL query that is compatible with the SQLite database.

The method first establishes a connection to the SQLite database using the connect() method of the sqlite3 module in Python. It then executes the SQL query using the execute() method of the connection object. After executing the query, the method commits the changes to the database using the commit() method of the connection object, and then closes the connection using the close() method of the connection object.

This method is typically used to insert metadata information into a SQLite database. It assumes that the SQLite database already exists and is located in the metadata_path directory.
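
A helper matching this description could look like the sketch below. The name execute_query, the db_path and params arguments, and the table layout are assumptions made for illustration; the database is created in a temporary folder so that the example is self-contained.

```python
import os
import sqlite3
import tempfile


def execute_query(query, db_path, params=()):
    """Run one SQL statement against a SQLite file and commit it (hypothetical helper)."""
    conn = sqlite3.connect(db_path)
    try:
        conn.execute(query, params)
        conn.commit()
    finally:
        # Close the connection even if the statement fails
        conn.close()


db_path = os.path.join(tempfile.mkdtemp(), "metadata.db")
execute_query("CREATE TABLE IF NOT EXISTS metadata (filename TEXT, mkey TEXT, mvalue TEXT)", db_path)
execute_query("INSERT INTO metadata VALUES (?, ?, ?)", db_path,
              ("image_0.jpg", "dominant_color", "#102030"))

# Read the row back to confirm that it was committed
conn = sqlite3.connect(db_path)
rows = conn.execute("SELECT * FROM metadata").fetchall()
conn.close()
print(rows)
```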

The following code block is used to process images and find their dominant colors. The code first retrieves all the images present in the folder specified by the images_path variable with get_all_images(). Then, it iterates over each image.

For each image, the code calls the find_dominant_colors() function, which reads the image using OpenCV's cv2.imread() function. If an error occurs for an image, the code prints the error message and continues to the next iteration of the loop.

The resulting list of colors is converted to a string, with single quotes replaced by double quotes so that it can be embedded in an SQL request. The requests are generated with gen_sql_requests() under the key "dominant_color" and written to the file colors_queries.sql, and the colors are then saved to the database.

In [None]:
import numpy as np


def get_all_colors(image_path):
    """
    This function extracts dominant colors from all images in a directory and saves the color information in the database.

    Parameters:
    image_path (str): The path to the directory where the images are stored.

    Returns:
    None
    """
    # Get a list of all images in the directory
    img_files = get_all_images(image_path)
    rows = []

    # Create a progress bar to track the progress of processing all images
    for img in tqdm(img_files, desc="Processing images (approx. 25 minutes)"):
        try:
            # Find the dominant colors of the image
            color = find_dominant_colors(img, downsample=2, resize=(100, 100))
        except Exception as e:
            print("Error: ", e)
            continue

        if color:
            # color to string, with quotes replaced by double quotes
            color = str(color).replace("'", '"')
            # Keep the file name next to its colors, so that a skipped image
            # cannot shift the pairing between file names and colors
            rows.append((os.path.basename(img), "dominant_color", color))

    queries = gen_sql_requests([row[0] for row in rows], [row[2] for row in rows])

    # save queries in a file colors_queries.sql
    with open('colors_queries.sql', 'w') as f:
        f.write(";\n".join(queries))

    if not rows:
        return

    # Save to the database with a parameterized insert (LOAD DATA expects
    # delimited data, not INSERT statements)
    conn = create_server_connection(sql_host, sql_user, sql_password, sql_database)
    cursor = conn.cursor()
    cursor.executemany("INSERT INTO metadata (filename, mkey, mvalue) VALUES (%s, %s, %s)", rows)
    conn.commit()
    conn.close()

In [None]:
get_all_colors(images_path)

# Notebook 3

# Imports

The code below imports various Python modules and libraries for data processing, visualization, and analysis. Below is a description of each module and library imported:

- `os`: Provides a way to interact with the operating system, such as navigating directories and working with files.
- `ast`: Provides a way to parse Python code into an abstract syntax tree, which can be used to analyze and manipulate the code.
- `spacy`: A library for natural language processing, including tasks such as tokenization, part-of-speech tagging, and named entity recognition.
- `folium`: A library for creating interactive maps using the Leaflet JavaScript library.
- `sqlite3`: A module for working with SQLite databases.
- `squarify`: A library for generating treemaps, which visualize hierarchical data using nested rectangles.
- `itertools`: Provides a collection of functions for working with iterators, such as combining multiple iterators or creating permutations.
- `webcolors`: A library for working with CSS-style color strings.
- `tqdm`: A library for creating progress bars for loops.
- `pandas`: A library for data manipulation and analysis, including reading and writing data to/from various file formats.
- `ipywidgets`: Provides interactive widgets for Jupyter notebooks and other IPython environments.
- `matplotlib.pyplot`: A library for creating visualizations, including line plots, scatter plots, bar charts, and histograms.
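- `collections.Counter`: A container that keeps track of the frequency of elements in a collection.
- `datetime`: Provides classes for working with dates and times.
- `geopy.geocoders.Nominatim`: A geocoder that uses OpenStreetMap's Nominatim service to convert between addresses and coordinates.
- `dotenv.load_dotenv`: Loads the variables defined in a `.env` file into the environment.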

In [None]:
import os
import ast
import spacy
import folium
import sqlite3
import datetime
import squarify
import itertools
import webcolors
import pandas as pd
from tqdm import tqdm
import ipywidgets as widgets
from collections import Counter
import matplotlib.pyplot as plt
from geopy.geocoders import Nominatim
from dotenv import load_dotenv
load_dotenv()

# Setting base variables and paths

The code below sets the base folder path and defines the paths of three subdirectories within that base path: "images", "metadata", and "config". These paths are built by joining the base path with their respective names using the os.path.join() function. The code also defines the whitelist of metadata properties to keep and reads the SQL connection settings from environment variables.

In [None]:
# Set the base folder path for the project
output_path = "../output"
images_path = os.path.join(output_path, "images")
metadata_path = os.path.join(output_path, "metadata")
config_path = os.path.join(output_path, "config")

list_of_paths = [output_path, images_path, metadata_path, config_path]
whitelist = ['Make', 'DateTimeOriginal', 'ImageWidth', 'ImageHeight', 'filename', 'Artist', 'Latitude', 'Longitude', 'Orientation', 'tags']

# Set SQL variables
sql_host = os.getenv("SQL_HOST")
sql_user = os.getenv("SQL_USER")
sql_password = os.getenv("SQL_PASSWORD")
sql_database = os.getenv("SQL_DATABASE")

# Get the metadata from the database and organize it into a dictionary

### Get the metadata from the database

The function "get_metadata_from_sqlite_db" is used to retrieve metadata from a SQLite database. It takes an optional argument "db_name" to specify the name of the database file inside the metadata directory. It opens a connection to the database, creates a cursor, and retrieves the metadata of every file, grouped by filename. The metadata is stored in a dictionary, where each filename is a key and its value is a dictionary of the properties listed in the whitelist. The function then closes the connection and returns the dictionary of metadata. The function "get_metadata_from_mariadb_db" does the same against the MariaDB server.

In [None]:
def get_metadata_from_sqlite_db(db_name='metadata.db') :
    """
    Get the metadata from the sqlite database

    :param db_name: The name of the database
    :return: A dictionary with the metadata
    """
    # Open a connection to the database
    conn = sqlite3.connect(os.path.join(metadata_path, db_name))
    # Create a cursor
    c = conn.cursor()

    # Retrieve the metadata
    c.execute("""
        SELECT filename, GROUP_CONCAT(key || '\t' || value, '\n') AS metadata
        FROM metadata
        GROUP BY filename;
    """)
    metadata = c.fetchall()

    # Close the connection
    conn.close()

    # Convert the metadata to a dictionary
    result = {}
    for image in tqdm(metadata, desc="Get metadata from database"):
        try :
            result[image[0]] = {}
            props = image[1].split('\n')
            for prop in props:
                if prop:
                    k, value = prop.split('\t')
                    if k in whitelist:
                        result[image[0]][k] = value
        except Exception as e:
            print(e, image)

    return result

In [None]:
import os
import mysql.connector
from tqdm import tqdm

def get_metadata_from_mariadb_db(db_name='bigdata', user='root', password='', host='localhost', port='3306'):
    """
    Get the metadata from the MariaDB database

    :param db_name: The name of the database
    :param user: The username to connect to the database
    :param password: The password to connect to the database
    :param host: The hostname or IP address of the database server
    :param port: The port number to connect to the database server
    :return: A dictionary with the metadata
    """
    # Open a connection to the database
    conn = mysql.connector.connect(
        user=user,
        password=password,
        host=host,
        port=port,
        database=db_name
    )
    # Create a cursor
    c = conn.cursor()

    # Retrieve the metadata
    c.execute("""
        SELECT filename, GROUP_CONCAT(CONCAT(mkey, '\t', mvalue) SEPARATOR '\n') AS metadata
        FROM metadata
        GROUP BY filename;
    """)
    metadata = c.fetchall()

    # Close the connection
    conn.close()

    # Convert the metadata to a dictionary
    result = {}
    for image in tqdm(metadata, desc="Get metadata from database"):
        try:
            result[image[0]] = {}
            props = image[1].split('\n')
            for prop in props:
                if prop:
                    k, value = prop.split('\t')
                    if k in whitelist:
                        result[image[0]][k] = value
        except Exception as e:
            print(e, image)

    return result

get_metadata_from_mariadb_db(sql_database, sql_user, sql_password, sql_host)
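
The parsing step shared by both functions can be tested without a database. The sketch below applies it to made-up rows in the same shape as the GROUP_CONCAT result. The helper name parse_grouped_metadata is hypothetical, and str.partition is used so that a value containing a tab does not raise an error.

```python
def parse_grouped_metadata(rows, whitelist):
    """Turn (filename, grouped_text) rows into {filename: {key: value}} (illustrative sketch)."""
    result = {}
    for filename, grouped in rows:
        props = {}
        # A file without metadata may come back as None
        for prop in (grouped or "").split("\n"):
            if not prop:
                continue
            # partition splits on the first tab only
            key, _, value = prop.partition("\t")
            if key in whitelist:
                props[key] = value
        result[filename] = props
    return result


# Made-up rows: one file with three properties, one file without metadata
rows = [
    ("image_0.jpg", "Make\tCanon\nSoftware\tGIMP\ntags\t['cat', 'dog']"),
    ("image_1.jpg", None),
]
parsed = parse_grouped_metadata(rows, whitelist=["Make", "tags"])
print(parsed)
```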

### Clean the metadata

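The function "clean_metadata" is used to clean the metadata. It takes a dictionary of metadata as an argument and returns a dictionary with the cleaned metadata. The function removes non-alphabetic characters and company suffixes (such as CORPORATION or LTD) from the 'Make' property values, converts the 'DateTimeOriginal' property values to datetime objects (invalid dates and dates in a future year are set to None), and parses the 'tags' property values into Python lists.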

In [None]:
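def clean_metadata(metadata_to_clean):
    """
    Clean the metadata
    Remove non-alphabetic characters and company suffixes from the 'Make' property values
    Convert the 'DateTimeOriginal' property values to datetime objects
    Parse the 'tags' property values into lists

    :param metadata_to_clean: The metadata to clean
    :return: A dictionary with the cleaned metadata
    """
    # Copy each inner dictionary too, so that the input metadata is left unchanged
    cln_meta = {file: dict(props) for file, props in metadata_to_clean.items()}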

    # Clean 'Make' property values
    try:

        for file in tqdm(cln_meta, desc="Clean 'Make' property values"):
            if 'Make' in cln_meta[file]:
                cln_meta[file]['Make'] = ''.join(filter(str.isalpha, cln_meta[file]['Make'])).replace('CORPORATION', '').replace('CORP', '').replace('COMPANY', '').replace('LTD', '').replace('IMAGING', '')
    except Exception as e:
        print(e, file)

    try:
        # Clean 'DateTime' property values
        cpt, cpt_error = 0, 0
        date_error = []

        for file in tqdm(cln_meta, desc="Clean 'DateTime' property values"):
            if 'DateTimeOriginal' in cln_meta[file]:
                date = cln_meta[file]['DateTimeOriginal']
                try :
                    if date is not None:
                        tmp = date.replace('T', ' ').replace('-', ':').split('+')[0]
                        cln_meta[file]['DateTimeOriginal'] = datetime.datetime.strptime(tmp[:19], '%Y:%m:%d %H:%M:%S')
                        # if the year is after actual year, we assume that the date is wrong
                        if cln_meta[file]['DateTimeOriginal'].year > datetime.datetime.now().year:
                            date_error.append(cln_meta[file]['DateTimeOriginal'])
                            cln_meta[file]['DateTimeOriginal'] = None
                            cpt_error += 1
                        else:
                            cpt += 1
                except ValueError:
                    date_error.append(date)
                    cln_meta[file]['DateTimeOriginal'] = None
                    cpt_error += 1
    except Exception as e:
        print(e, file)

    print(f"Metadata cleaned ! {cpt}/{len(cln_meta)} dates OK, {cpt_error} dates KO")
    print(f"Dates KO : {date_error}")

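    # Clean 'tags' property values
    import ast  # local import: literal_eval parses literals without executing code

    for file in tqdm(cln_meta, desc="Clean 'tags' property values"):
        # Only convert when a value is present; None is kept as-is
        if cln_meta[file].get('tags') is not None:
            cln_meta[file]['tags'] = ast.literal_eval(cln_meta[file]['tags'])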

    return cln_meta
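
The date normalisation performed by "clean_metadata" can be tried in isolation. The following is a minimal sketch, not part of the notebook: the helper name "parse_exif_datetime" and the sample strings are assumptions made for illustration.

```python
import datetime


def parse_exif_datetime(raw):
    """Parse an EXIF-style or ISO-style date string; return None if it is unusable."""
    if raw is None:
        return None
    # '2021-03-04T10:20:30+02:00' becomes '2021:03:04 10:20:30'
    tmp = raw.replace('T', ' ').replace('-', ':').split('+')[0]
    try:
        parsed = datetime.datetime.strptime(tmp[:19], '%Y:%m:%d %H:%M:%S')
    except ValueError:
        return None
    # A year in the future is assumed to come from a wrong camera clock
    if parsed.year > datetime.datetime.now().year:
        return None
    return parsed


print(parse_exif_datetime('2021:03:04 10:20:30'))
print(parse_exif_datetime('2021-03-04T10:20:30+02:00'))
print(parse_exif_datetime('not a date'))
```

Returning None instead of raising keeps one malformed value from interrupting the cleaning of the whole collection.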

In [None]:
# Get the metadata from the MariaDB database
brut_metadata = get_metadata_from_mariadb_db(sql_database, sql_user, sql_password, sql_host)
# Clean the metadata
cln_metadata = clean_metadata(brut_metadata)
# Convert the metadata to a DataFrame
df_metadata = pd.DataFrame.from_dict(cln_metadata).transpose()
# Assign the result back: fillna(inplace=True) on a selected column is a chained assignment
df_metadata['Make'] = df_metadata['Make'].fillna('Undefined')
df_metadata['GPSInfo'] = df_metadata['GPSInfo'].fillna('Undefined')
df_metadata.head()

### Overview of the metadata

The function "count_data_per_property" counts the number of non-null values for each property (column) of the metadata DataFrame and prints the properties whose count exceeds "significant_limit" (70 by default).

The function "metadata_extract_example" takes the dict_metadata dictionary as an argument and prints the first 3 elements of each of its lists.

In [None]:
def count_data_per_property(metadata_to_count, significant_limit=70):
    """
    Count the number of non-null values for each property in the metadata dictionary
    Display the properties that have significant non-null values

    :param metadata_to_count: The metadata to count
    :param significant_limit: The limit after which a property is considered significant
    """
    # Count the number of non-null values for each property in the metadata dictionary
    prop_len = {}
    for prop in metadata_to_count:
        prop_len[prop] = metadata_to_count[prop].count()

        # Print the properties that have more than significant_limit non-null values
        if prop_len[prop] > significant_limit:
            print(f'{prop} : {prop_len[prop]}')

In [None]:
print(f'Number of images : {len(df_metadata["File Name"])}')
print("-------------- Properties with a significant number of non-null values --------------")
count_data_per_property(df_metadata)

# Define the functions to display the metadata

- "display_bar" is used to display a bar chart.
- "display_pie" is used to display a pie chart.
- "display_curve" is used to display a curve chart.

In [None]:
def display_bar(title, x_label, y_label, x_values, y_values):
    """
    Display a bar chart

    :param title: The title of the chart
    :param x_label: The x-axis label
    :param y_label: The y-axis label
    :param x_values: The values of the x-axis
    :param y_values: The values of the y-axis
    """
    plt.bar(x_values, y_values)
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.xticks(rotation=90)
    plt.show()

In [None]:
def display_pie(title, values, labels):
    """
    Display a pie chart

    :param title: The title of the chart
    :param values: The values of the chart
    :param labels: The labels of the chart
    """
    plt.pie(values, labels=labels, autopct='%1.1f%%')
    plt.title(title)
    plt.show()

In [None]:
def display_curve(title, x_label, y_label, x_values, y_values):
    """
    Display a curve

    :param title: The title of the curve
    :param x_label: The label of the x_axis
    :param y_label: The label of the y_axis
    :param x_values: The values of the x_axis
    :param y_values: The values of the y_axis
    """

    plt.plot(x_values, y_values)
    plt.xticks(rotation=90)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    plt.show()

# Graph images : size (static)

The function "graph_images_size_static" graphs the number of images per size category, where the size of an image is its smaller dimension (the minimum of its width and height). It takes the metadata DataFrame as an argument and displays a bar chart. By default the interval size is 200 and the number of intervals is 4.


In [None]:
def graph_images_size_static(df_meta, interval_size=200, nb_intervals=4):
    """
    Graph the number of images per size category
    The interval size is 200 by default

    :param df_meta: The metadata to graph
    :param interval_size: The size of the intervals
    :param nb_intervals: The number of intervals
    """

    # Calculate the minimum size of each image and store it in a new column
    # (values read from the database may be strings, so convert them to numbers first)
    sizes = df_meta[['Width', 'Height']].apply(pd.to_numeric, errors='coerce')
    df_meta['min_size'] = sizes.min(axis=1)

    # Create a list of interval bounds based on the interval size and number of intervals
    inter = [i * interval_size for i in range(nb_intervals + 1)]

    # Create a list of labels for each interval
    labels = [f'{inter[i]}-{inter[i + 1]}' for i in range(nb_intervals)]

    # Categorize each image based on its size and interval
    df_meta['size_category'] = pd.cut(df_meta['min_size'], bins=inter, labels=labels)

    # Count the number of images in each category, keeping the categories in size order
    size_counts = df_meta['size_category'].value_counts().sort_index()

    display_bar('Number of images per size category', 'Size category', 'Number of images', size_counts.index, size_counts.values)
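
The binning step can be checked on a handful of values. This is a minimal sketch with made-up sizes, assuming pandas is installed. It shows that "pd.cut" uses intervals that are open on the left and closed on the right, and that a value beyond the last bound is left uncategorized.

```python
import pandas as pd

# Hypothetical minimum sizes (in pixels) for six images
min_sizes = pd.Series([500, 1500, 2500, 3500, 4500, 7000])

interval_size, nb_intervals = 2000, 3
edges = [i * interval_size for i in range(nb_intervals + 1)]  # [0, 2000, 4000, 6000]
labels = [f'{edges[i]}-{edges[i + 1]}' for i in range(nb_intervals)]

# Each interval is open on the left and closed on the right: (0, 2000], (2000, 4000], ...
categories = pd.cut(min_sizes, bins=edges, labels=labels)

# sort_index keeps the categories in size order instead of ordering them by count
counts = categories.value_counts().sort_index()
print(counts)
print(f'Uncategorized images: {categories.isna().sum()}')
```

The image of size 7000 falls outside the last interval and is not counted, so the intervals should be chosen to cover the largest expected size.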

In [None]:
graph_images_size_static(df_metadata, 2000, 3)

# Graph images : size (dynamic)

The function "graph_images_size_dynamic" graphs the number of images per size category. It takes the metadata DataFrame as an argument and displays the chart(s). The interval size is calculated dynamically by dividing the largest size by the number of intervals, which is 7 by default.

You can choose the type of graph to display (bar, pie or all for both).

In [None]:
def graph_images_size_dynamic(df_meta, nb_intervals=7, graph_type='all'):
    """
    Graph the number of images per size category
    The interval size is calculated dynamically

    :param df_meta: The metadata to graph
    :param nb_intervals: The number of intervals in the graph
    :param graph_type: The type of graph to display (bar, pie or all for both)
    """

    # Calculate the minimum size of each image and store it in a new column
    # (values read from the database may be strings, so convert them to numbers first)
    sizes = df_meta[['Height', 'Width']].apply(pd.to_numeric, errors='coerce')
    df_meta['min_size'] = sizes.min(axis=1)

    # Determine the maximum minimum size; the bin width is derived from it
    max_min_size = df_meta['min_size'].max()
    num_bins = nb_intervals

    # Create a list of bin edges based on the maximum minimum size and number of bins
    bins = [i * (max_min_size / num_bins) for i in range(num_bins + 1)]

    # Create a list of labels for each bin
    labels = [f'{int(bins[i])}-{int(bins[i + 1])}' for i in range(num_bins)]

    # Categorize each image based on its size and bin
    df_meta['size_category'] = pd.cut(df_meta['min_size'], bins=bins, labels=labels)

    # Count the number of images in each category
    size_counts = df_meta['size_category'].value_counts()

    title = 'Number of images per size category'

    # Create the appropriate chart based on the graph type parameter
    if graph_type == 'bar':
        display_bar(title, 'Image size', 'Number of images', size_counts.index, size_counts.values)
    elif graph_type == 'pie':
        display_pie(title, size_counts.values, size_counts.index)
    elif graph_type == 'all':
        display_bar(title, 'Image size', 'Number of images', size_counts.index, size_counts.values)
        display_pie(title, size_counts.values, size_counts.index)

    else:
        raise ValueError('Invalid graph type')


In [None]:
graph_images_size_dynamic(df_metadata, 5, 'all')

# Graph images : DateTime

The function "graph_images_datetime" graphs the number of images per year. It takes the metadata DataFrame as an argument and displays the "nb_intervals" years with the most images (10 by default).

You can choose the type of graph to display (bar, pie, curve or all for all).


In [None]:
def graph_images_datetime(df_meta, nb_intervals=10, graph_type='all'):
    """
    Graph the number of images per year

    :param df_meta: The metadata to graph (expects a pandas DataFrame with a 'DateTimeOriginal' column)
    :param graph_type: The type of graph to display (bar, pie, curve or all for all)
    :param nb_intervals: The number of years to display (those with the most images)
    """

    # Extract the year from the 'DateTimeOriginal' column and create a new 'Year' column
    df_meta['Year'] = pd.DatetimeIndex(df_meta['DateTimeOriginal']).year


    # Group the data by year and count the number of images for each year
    image_count = df_meta.groupby('Year').size().reset_index(name='count').sort_values('count', ascending=False)[
                  :nb_intervals]
    image_count['Year'] = image_count['Year'].astype(int)

    # Set the title of the graph
    title = 'Number of images per year'

    # Display different types of graphs based on the 'graph_type' parameter
    if graph_type == 'bar':
        # Display a bar chart using the custom function 'display_bar'
        display_bar(title, 'Year', 'Number of images', image_count['Year'], image_count['count'])

    elif graph_type == 'pie':
        # Display a pie chart using a custom function 'display_pie'
        display_pie(title, image_count['count'], image_count['Year'])

    elif graph_type == 'curve':
        # Display a line chart using a custom function 'display_curve'
        image_count = df_meta.groupby('Year').size().reset_index(name='count').sort_values('Year', ascending=True)
        display_curve(title, 'Year', 'Number of images', image_count['Year'], image_count['count'])

    elif graph_type == 'all':
        # Display all three types of graphs: bar, pie, and line charts

        # Bar chart
        display_bar(title, 'Year', 'Number of images', image_count['Year'], image_count['count'])

        # Pie chart
        display_pie(title, image_count['count'], image_count['Year'])

        # Line chart
        image_count = image_count.sort_values('Year', ascending=True)
        display_curve(title, 'Year', 'Number of images', image_count['Year'], image_count['count'])
    else:
        # Raise an error if an invalid 'graph_type' parameter is passed
        raise ValueError('Invalid graph type')
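
The per-year count can be reproduced on a few dates. This is a minimal sketch with made-up values, assuming pandas is installed. It shows how missing dates (None after cleaning) are dropped before counting.

```python
import datetime

import pandas as pd

# Hypothetical cleaned dates; None stands for a date that could not be parsed
dates = pd.Series([
    datetime.datetime(2019, 5, 1, 12, 0, 0),
    datetime.datetime(2019, 8, 15, 9, 30, 0),
    datetime.datetime(2021, 1, 2, 18, 45, 0),
    None,
])

# Missing dates become NaT, which yields a missing year that is dropped before counting
years = pd.to_datetime(dates, errors='coerce').dt.year.dropna().astype(int)
per_year = years.value_counts().sort_index()
print(per_year)
```

Sorting by index gives the chronological order needed for a curve, whereas sorting by value gives the ranking used for the bar and pie charts.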

In [None]:
graph_images_datetime(df_metadata, 10, 'all')

# Graph images : Brand

The function "graph_images_brand" graphs the number of images per brand. It takes the metadata DataFrame as an argument and displays the brands with the most images.

You can choose the type of graph to display (bar, pie or all for both), and the number of columns to display.

In [None]:
def graph_images_brand(df_meta, graph_type='all', nb_columns=5):
    """
    Graph the number of images per brand

    :param df_meta: The metadata to graph
    :param graph_type: The type of graph to display (bar, pie or all for both)
    :param nb_columns: The number of columns to display
    """

    # Initialize an empty dictionary to store the counts of each brand
    counts = {}

    # Loop through each brand in the metadata and count the number of occurrences
    for make in df_meta['Make']:
        if make is not None :
            counts[make] = counts.get(make, 0) + 1

    sorted_counts = dict(sorted(counts.items(), key=lambda x: x[1], reverse=True))

    # Convert the dictionary into two lists of labels and values for graphing
    labels = list(sorted_counts.keys())[:nb_columns]
    values = list(sorted_counts.values())[:nb_columns]

    # Set the title for the graph
    title = 'Number of images per brand'

    # Determine which type of graph to display based on the 'graph_type' parameter
    if graph_type == 'bar':
        # Display a bar graph
        display_bar(title, 'Brand', 'Number of images', labels, values)
    elif graph_type == 'pie':
        # Display a pie chart
        display_pie(title, values, labels)
    elif graph_type == 'all':
        # Display both a bar graph and a pie chart
        display_bar(title, 'Brand', 'Number of images', labels, values)
        display_pie(title, values, labels)
    else:
        # Raise an error if the 'graph_type' parameter is invalid
        raise ValueError('Invalid graph type')

In [None]:
graph_images_brand(df_metadata, 'all', 10)

# Graph images : Images with GPS

#### Overview

The function "gps_info_overview" counts the images with GPS data. It takes the metadata DataFrame as an argument and prints the number of images with GPS data.

In [None]:
def gps_info_overview(df_meta):
    """
    Display the number of images with GPS data
    """
    cpt = 0
    # Count the images with GPS data (a GPSInfo string longer than 24 characters)
    for idx, meta in enumerate(df_meta['GPSInfo']):
        if meta is not None and len(meta) > 24:
            # print(dict_metadata['file'][idx])
            # print(meta)
            cpt += 1
    print(f"Number of images with GPS data : {cpt}")

In [None]:
gps_info_overview(df_metadata)

### Extract GPS coordinates

The function "get_coordinates" extracts the coordinates of the images with GPS data. It takes a dictionary of metadata as an argument and returns a dictionary that maps each file name to its [latitude, longitude] pair.

It uses the function "dms_to_dd" to convert the coordinates from DMS (degrees, minutes, seconds) to DD (decimal degrees).

In [None]:
def dms_to_dd(degrees, minutes, seconds, direction):
    """
    Convert DMS (degrees, minutes, seconds) coordinates to DD (decimal degrees)

    :param degrees: degrees
    :param minutes: minutes
    :param seconds: seconds
    :param direction: direction (N, S, E, W)
    :return: decimal degrees
    """

    dd = float(degrees) + float(minutes) / 60 + float(seconds) / (60 * 60)
    if direction == 'S' or direction == 'W':
        dd *= -1
    return dd
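
As a worked example of the conversion, the sketch below converts two made-up coordinates. The function is repeated in a compact form only so that the example runs on its own.

```python
def dms_to_dd(degrees, minutes, seconds, direction):
    """Same conversion as above, repeated so that the example is self-contained."""
    dd = float(degrees) + float(minutes) / 60 + float(seconds) / 3600
    return -dd if direction in ('S', 'W') else dd


# 48 degrees 51 minutes 30 seconds North: 48 + 51/60 + 30/3600
latitude = dms_to_dd(48, 51, 30, 'N')
# 2 degrees 17 minutes 40 seconds West: -(2 + 17/60 + 40/3600)
longitude = dms_to_dd(2, 17, 40, 'W')
print(round(latitude, 6), round(longitude, 6))
```

South and West are negative in decimal degrees, which is why the second value changes sign.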

In [None]:
def get_coordinates(metadata):
    """
    Extract the coordinates of the images with GPS data

    :param metadata: The metadata to extract the coordinates from
    """
    coordinates = {}
    for i, gps_info in enumerate(metadata['GPSInfo']):
        if gps_info is not None:
            try:
                # The GPS data is stored as the string representation of a dictionary.
                # eval executes arbitrary code, so it is only acceptable for trusted data.
                gps_info = eval(gps_info)
                latitude, longitude = None, None
                for key, val in gps_info.items():
                    # A reference letter (N/S or E/W) is followed by the DMS triple under the next key
                    if val == 'N' or val == 'S':
                        nxt = gps_info[key + 1]
                        latitude = dms_to_dd(nxt[0], nxt[1], nxt[2], val)
                    elif val == 'E' or val == 'W':
                        nxt = gps_info[key + 1]
                        longitude = dms_to_dd(nxt[0], nxt[1], nxt[2], val)

                if latitude is not None and longitude is not None:
                    # Use the metadata passed as argument rather than the global dict_metadata
                    coordinates[metadata['file'][i]] = [latitude, longitude]
            except Exception:
                print(f"Error with {metadata['file'][i]}")
                # print(gps_info)
    print(f"Number of images with valid GPS data : {len(coordinates)}")

    return coordinates
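
The structure that "get_coordinates" expects can be illustrated on a single value. This is a minimal sketch, assuming the GPSInfo dictionary is keyed by the EXIF GPS tag numbers (1: GPSLatitudeRef, 2: GPSLatitude, 3: GPSLongitudeRef, 4: GPSLongitude), which matches the "key + 1" lookup above. The helper name "parse_gps_info" and the sample string are hypothetical, and "ast.literal_eval" is swapped in for "eval" because it only accepts Python literals.

```python
import ast


def dms_to_dd(degrees, minutes, seconds, direction):
    """Convert degrees, minutes, seconds to decimal degrees."""
    dd = float(degrees) + float(minutes) / 60 + float(seconds) / 3600
    return -dd if direction in ('S', 'W') else dd


def parse_gps_info(raw):
    """Return [latitude, longitude] from a stringified GPSInfo dictionary, or None."""
    try:
        # literal_eval only accepts Python literals, unlike eval
        gps_info = ast.literal_eval(raw)
        latitude = dms_to_dd(*gps_info[2], gps_info[1])
        longitude = dms_to_dd(*gps_info[4], gps_info[3])
    except (ValueError, SyntaxError, KeyError, TypeError):
        return None
    return [latitude, longitude]


sample = "{1: 'N', 2: (48.0, 51.0, 30.0), 3: 'E', 4: (2.0, 17.0, 40.0)}"
print(parse_gps_info(sample))
print(parse_gps_info('Undefined'))
```

The second call shows that the 'Undefined' placeholder written by "fillna" is rejected cleanly instead of raising an error.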

In [None]:
coordinates = get_coordinates(dict_metadata)

### Using a Map with markers

The function "display_coordinates_on_map" is used to display the coordinates of the images with GPS data on a map. It takes a dictionary of coordinates as an argument and returns a map with the coordinates displayed as markers.

Inside the method comments, you can find a way to display the images as markers instead of the default markers.

In [None]:
def display_coordinates_on_map(coordinates_list):
    """
    Display the coordinates on a map

    :param coordinates_list: The coordinates to display
    :return: The map with the coordinates displayed as markers
    """

    # create a map centered at a specific location
    m = folium.Map(location=[0, 0], zoom_start=1)

    # add markers for each set of coordinates
    for image, coords in coordinates_list.items():
        lat, lon = coords

        # Create a marker with the image as the icon
        # Warning: the image must be downloaded first, and you need to add:
        # from folium.features import CustomIcon

        # image_path = '../output/images/' + image
        # icon = CustomIcon(icon_image=image_path, icon_size=(100, 100))
        # folium.Marker(location=[lat, lon], icon=icon).add_to(m)
        folium.Marker(location=[lat, lon], tooltip=image, popup=f'file:{image}\ncoord:{coords}').add_to(m)
    return m

In [None]:
display_coordinates_on_map(coordinates)

### Using graphs by country

#### Get country

The function "get_country" looks up the country of each coordinate. It takes a dictionary of coordinates as an argument, appends the country to each [latitude, longitude] pair, and returns the updated dictionary.

In [None]:
def get_country(coordinates_list):
    """
    Get the country of each coordinate

    :param coordinates_list: The coordinates to get the country from
    :return: The coordinates with the country added
    """
    # Create a geolocator
    geolocator = Nominatim(user_agent="geoapiExercises")

    # Get the country information for each coordinate
    for key, coord in tqdm(coordinates_list.items(), desc='Getting country information'):
        if len(coord) < 3:  # If the country hasn't been found yet
            try:
                location = geolocator.reverse(coord, exactly_one=True, language='en')
                address = location.raw['address']
                country = address.get('country')
                # Update the dictionary passed as argument rather than the global variable
                coordinates_list[key].append(country)
            except Exception:
                print(f"Error with {key} : {coord}")

    return coordinates_list
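
Reverse geocoding needs network access, and the public Nominatim service asks clients to stay at or below one request per second. The sketch below separates the lookup from the bookkeeping so that the loop can be tried offline. The names "add_countries", "reverse_lookup" and "fake_lookup" are hypothetical; in the notebook the lookup would wrap "geolocator.reverse".

```python
import time


def add_countries(coords, reverse_lookup, delay=0.0):
    """Append a country to every [lat, lon] entry that does not have one yet."""
    cache = {}
    for name, coord in coords.items():
        if len(coord) >= 3:  # country already present
            continue
        # Images taken at (almost) the same place share one lookup
        key = (round(coord[0], 3), round(coord[1], 3))
        if key not in cache:
            cache[key] = reverse_lookup(coord[0], coord[1])
            time.sleep(delay)  # e.g. delay=1.0 against the public Nominatim service
        coord.append(cache[key])
    return coords


def fake_lookup(lat, lon):
    """Offline stand-in for a real geocoder (assumption made for the example)."""
    return 'France' if lat > 40 else 'Brazil'


points = {'a.jpg': [48.8583, 2.2944], 'b.jpg': [48.8583, 2.2944], 'c.jpg': [-22.9519, -43.2105]}
print(add_countries(points, fake_lookup))
```

Caching by rounded coordinates reduces the number of requests when many images were taken at the same location.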

In [None]:
get_country(coordinates)

### Display graphs

The function "graph_images_countries" displays graphs of the number of images by country. It takes a dictionary of coordinates (with the country added by "get_country") as an argument and displays the chart(s).

The parameter "nb_inter" sets the number of countries to display. The parameter "graph" sets the type of graph to display (bar, pie, all).

In [None]:
def graph_images_countries(coord_list, nb_inter=5, graph='all'):
    """
    Display graphs about the number of images by country

    :param coord_list: dictionary of coordinates (file name -> [latitude, longitude, country])
    :param nb_inter: number of countries to display
    :param graph: type of graph to display (bar, pie, all)
    """

    # Create a pandas DataFrame from the coordinates dictionary
    df = pd.DataFrame.from_dict(coord_list, orient='index', columns=['Latitude', 'Longitude', 'Country'])

    # Group the DataFrame by country and count the number of images
    country_count = df.groupby('Country')['Country'].count()
    country_count = country_count.sort_values(ascending=False)[:nb_inter]

    title = 'Number of images by country'

    if graph == 'bar':
        display_bar(title, "Country", "Image Count", country_count.index, country_count.values)
    elif graph == 'pie':
        display_pie(title, country_count.values, country_count.index)
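    elif graph == 'all':
        display_bar(title, "Country", "Image Count", country_count.index, country_count.values)
        display_pie(title, country_count.values, country_count.index)
    else:
        # Raise an error if the 'graph' parameter is invalid, as the other graph functions do
        raise ValueError('Invalid graph type')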

In [None]:
graph_images_countries(coordinates, 10, 'all')

# Graph images : by Dominant Color

In [None]:
MAX_COLUMNS = 20


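def closest_colour(requested_colour):
    """
    Find the CSS3 colour name closest to an RGB colour (smallest squared Euclidean distance)

    :param requested_colour: The colour as an (r, g, b) tuple
    :return: The name of the closest CSS3 colour
    """
    min_colours = {}
    for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
        r_c, g_c, b_c = webcolors.hex_to_rgb(key)
        rd = (r_c - requested_colour[0]) ** 2
        gd = (g_c - requested_colour[1]) ** 2
        bd = (b_c - requested_colour[2]) ** 2
        min_colours[(rd + gd + bd)] = name
    return min_colours[min(min_colours.keys())]


def get_colour_name(requested_colour):
    """
    Get the exact CSS3 name of an RGB colour if it has one, and the closest name in any case

    :param requested_colour: The colour as an (r, g, b) tuple
    :return: A tuple (actual_name, closest_name); actual_name is None when there is no exact match
    """
    try:
        closest_name = actual_name = webcolors.rgb_to_name(requested_colour)
    except ValueError:
        closest_name = closest_colour(requested_colour)
        actual_name = None
    return actual_name, closest_name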


dict_dom_color = {}
for idx, dom_color in enumerate(dict_metadata['dominant_color']):
    if dom_color is not None:
        list_dom_color = eval(dom_color)
        dict_dom_color.update({dict_metadata['file'][idx]: list_dom_color})

color_counts = Counter()
for image_colors in dict_dom_color.values():
    for color, percentage in image_colors:
        color_counts[color] += percentage

# Map hexadecimal codes to color names
color_names = {}
for code in color_counts.keys():
    try:
        rgb = webcolors.hex_to_rgb(code)
        actual, closest = get_colour_name(rgb)
        color_names[code] = closest
    except ValueError:
        pass

# Sum the weights of all hexadecimal codes that map to the same colour name
dict_res = {}
for key, val in color_names.items():
    dict_res[val] = dict_res.get(val, 0) + round(color_counts[key] / 100, 5)

if sum(dict_res.values()) > 100:
    raise Exception('sum of dict_res.values() > 100')

# Display at most MAX_COLUMNS colours
columns = min(len(dict_res), MAX_COLUMNS)

sorted_colors = sorted(dict_res.items(), key=lambda x: x[1], reverse=True)
top_colors = dict(sorted_colors[:columns])

# Create a bar graph showing the dominant colors in the images
plt.bar(top_colors.keys(), top_colors.values(), color=top_colors.keys())
plt.xticks(rotation=90)
plt.show()

# Create a pie chart showing the dominant colors in the images
fig, ax = plt.subplots(figsize=(8, 8))
colors = list(top_colors.keys())
ax.pie(top_colors.values(), labels=top_colors.keys(), autopct='%1.1f%%', colors=colors, textprops={'color': 'white'})
ax.set_title('Top Colors')
ax.legend(title='Colors', loc='center right', bbox_to_anchor=(1.2, 0.5))
plt.show()

# Create a treemap showing the dominant colors in the images
color = [webcolors.name_to_hex(c) for c in top_colors]
labels = list(top_colors.keys())
sizes = list(top_colors.values())
squarify.plot(sizes=sizes, label=labels, color=color, alpha=.7)
plt.title("Top Colors")
plt.axis('off')
plt.show()
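
The nearest-colour lookup used above can be illustrated without the webcolors dependency. This is a minimal sketch with a tiny hypothetical palette; the notebook itself uses the full CSS3 list, and the names "PALETTE" and "closest_name" are assumptions made for the example.

```python
# A tiny hypothetical palette; the notebook uses the full CSS3 list from webcolors
PALETTE = {
    'black': (0, 0, 0),
    'white': (255, 255, 255),
    'red': (255, 0, 0),
    'green': (0, 128, 0),
    'blue': (0, 0, 255),
}


def closest_name(rgb, palette=PALETTE):
    """Return the palette name with the smallest squared RGB distance to rgb."""
    def squared_distance(name):
        return sum((a - b) ** 2 for a, b in zip(palette[name], rgb))
    return min(palette, key=squared_distance)


print(closest_name((250, 10, 10)))
print(closest_name((20, 30, 200)))
```

Taking the minimum with a key function avoids building a dictionary keyed by distance, in which two names at the same distance would overwrite each other.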

# Graph images : by Tags

In [None]:
# convert tag strings to a list of tags
#tags = list(itertools.chain.from_iterable([ast.literal_eval(t) for t in dict_metadata['tags']]))

tags = dict_metadata['tags']
# flatten the list of tags
tags = [item for sublist in tags for item in sublist]

# count the occurrences of each tag
tag_counts = Counter(tags)

# plot the most common tags
n = 10
top_tags = dict(tag_counts.most_common(n))
plt.bar(top_tags.keys(), top_tags.values())
plt.title(f"Top {n} most common tags")
plt.xlabel("Tags")
plt.ylabel("Frequency")
plt.xticks(rotation=90)
plt.show()
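
When the tags are still stored as strings, they have to be parsed before being flattened and counted. The following is a minimal sketch using only the standard library; the sample values are made up and assume one stringified list per image.

```python
import ast
from collections import Counter
from itertools import chain

# Hypothetical 'tags' values as they might be stored: one stringified list per image
raw_tags = ["['sea', 'beach']", "['beach', 'sunset']", "['beach']"]

# Parse each string into a list, then flatten the lists into a single sequence
all_tags = chain.from_iterable(ast.literal_eval(t) for t in raw_tags)
tag_counts = Counter(all_tags)
print(tag_counts.most_common(2))
```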


In [None]:
tags = dict_metadata['tags']
# flatten the list of tags
tags = [item for sublist in tags for item in sublist]

try:
    nlp = spacy.load("en_core_web_lg")  # load pre-trained word embedding model
except OSError:
    # The model is not installed yet: download it, then load it
    !python -m spacy download en_core_web_lg
    nlp = spacy.load("en_core_web_lg")

categories = {
    "landscape": {}, "animal": {}, "people": {}, "food": {}, "building": {}, "vehicle": {}, "object": {}, "other": {}
}

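# process each category prototype once instead of once per word
category_docs = {category: nlp(category) for category in categories}

# categorize words based on similarity to category prototypes
# (dict.fromkeys removes duplicate tags while keeping their order)
for word in dict.fromkeys(tags):
    word_doc = nlp(word)
    # find the most similar category prototype for the word
    max_similarity = -1
    chosen_category = "other"
    for category, category_doc in category_docs.items():
        similarity = word_doc.similarity(category_doc)
        if similarity > max_similarity:
            max_similarity = similarity
            chosen_category = category

    # add the word into the appropriate category dictionary
    categories[chosen_category][word] = max_similarity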

print(categories)
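
By default, spaCy's "similarity" is the cosine similarity between averaged word vectors. The sketch below reproduces the assignment step with toy 3-dimensional vectors; the numbers are made up and are not real embeddings.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors (0.0 if one of them is null)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


# Toy 3-dimensional "embeddings" (made-up numbers, not real spaCy vectors)
prototypes = {'animal': (0.9, 0.1, 0.0), 'food': (0.1, 0.9, 0.1), 'vehicle': (0.0, 0.1, 0.9)}
words = {'dog': (0.8, 0.2, 0.1), 'pizza': (0.2, 0.8, 0.0), 'truck': (0.1, 0.0, 0.7)}

assigned = {}
for word, vector in words.items():
    # Pick the prototype whose vector points in the most similar direction
    assigned[word] = max(prototypes, key=lambda name: cosine_similarity(vector, prototypes[name]))
print(assigned)
```

Because every word is assigned to its best-scoring prototype, the "other" category only receives words whose best match happens to be the word "other" itself.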

In [None]:
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
import numpy as np


# Convert the dictionary into a numpy array
def dict_to_array(categories):
    n_categories = len(categories)
    arr = np.zeros((n_categories, n_categories))
    for i, words in enumerate(categories.values()):
        # Keep at most n_categories similarity scores per category (square matrix)
        for j, val in enumerate(words.values()):
            if j < n_categories:
                arr[i, j] = val
    return arr


# Generate the linkage matrix
Z = linkage(dict_to_array(categories), 'ward')

# Plot the dendrogram
fig = plt.figure(figsize=(10, 5))
dn = dendrogram(Z, labels=list(categories.keys()))
plt.show()

In [None]:
tags = dict_metadata['tags']
# flatten the list of tags
tags = [item for sublist in tags for item in sublist]
user_data = {}

# Create a label for the title
title_label = widgets.Label(value='User Information Form')

# Create a text box for the user's pseudo
pseudo = widgets.Text(description='Pseudo :')

# Create a color picker for favorite colors
color_picker = widgets.ColorPicker(
    concise=True,
    description='Favorite Colors:',
    value='#FF0000',
    continuous_update=False,
    disabled=False
)

# Create a multiple-selection list of tags (duplicates removed)
tag_dropdown = widgets.SelectMultiple(
    options=sorted(set(tags)),
    value=[],
    description='Tags:',
    disabled=False
)

# Create a dropdown for image orientation
orientation_dropdown = widgets.Dropdown(
    options=['Portrait', 'Landscape'],
    value='Portrait',
    description='Orientation:'
)

# Create sliders for image height and width
height_slider = widgets.IntSlider(min=100, max=4000, step=100, description='Height:')
width_slider = widgets.IntSlider(min=100, max=4000, step=100, description='Width:')

#  Create a button to submit the form
submit_button = widgets.Button(description='Submit')

# Create a VBox container for the widgets
form_container = widgets.VBox([
    title_label,
    pseudo,
    color_picker,
    orientation_dropdown,
    height_slider,
    width_slider,
    tag_dropdown,
    submit_button
])

form_container.layout = widgets.Layout(
    width='600px',
    height='500px',
    justify_content='center',  # Center the widgets vertically (main axis of a VBox)
    align_items='center'  # Center the widgets horizontally (cross axis of a VBox)
)


# Define a function to handle form submission
def on_submit_button_clicked(b):
    user_data.update({
        pseudo.value: {
            'fav_color': color_picker.value,
            'fav_orientation': orientation_dropdown.value,
            'fav_height': height_slider.value,
            'fav_width': width_slider.value,
            'tags': tag_dropdown.value
        }
    }
    )


# Attach the on_submit_button_clicked function to the button click event
submit_button.on_click(on_submit_button_clicked)

# Display the form container
display(form_container)

In [None]:
print(user_data['Yannis']['tags'])

In [None]:
list_columns = ['fav_color', 'fav_orientation', 'fav_height', 'fav_width', 'tags']


def save_metadata(user_data):
    """
    Save the users' preferences to the 'users' table of the SQLite database

    :param user_data: A dictionary mapping each pseudo to its preferences
    """
    # Check that there is at least one user and that no expected value is missing
    if not user_data or any(user_data[pseudo].get(column) is None
                            for pseudo in user_data for column in list_columns):
        print("Invalid User data")
        return

    conn = None
    try:
        # Open a connection to the database
        conn = sqlite3.connect(os.path.join(metadata_path, 'metadata.db'))
        # Create a cursor
        c = conn.cursor()
        # Create the users table if it doesn't exist
        c.execute(
            '''CREATE TABLE IF NOT EXISTS users (
            pseudo text PRIMARY KEY,
            fav_color text,
            fav_orientation text,
            fav_height integer,
            fav_width integer,
            fav_tags text
        )''')

        for pseudo, prefs in user_data.items():
            values = (prefs['fav_color'], prefs['fav_orientation'], prefs['fav_height'],
                      prefs['fav_width'], str(prefs['tags']))
            c.execute("SELECT 1 FROM users WHERE pseudo=?", (pseudo,))
            if c.fetchone():
                # The column is named fav_tags in the table definition
                c.execute(
                    "UPDATE users SET fav_color=?, fav_orientation=?, fav_height=?, fav_width=?, fav_tags=? WHERE pseudo=?",
                    values + (pseudo,))
            else:
                c.execute("INSERT INTO users VALUES (?, ?, ?, ?, ?, ?)", (pseudo,) + values)
            conn.commit()
            print(f"User {pseudo} saved to database successfully")

    except Exception as e:
        print("Error while saving user data to database: " + str(e))
    finally:
        # Close the connection even if an error occurred
        if conn is not None:
            conn.close()


save_metadata(user_data)
print(user_data)
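
The select-then-update-or-insert logic can also be expressed as a single statement. The following is a minimal sketch, not the notebook's method: it uses an in-memory database, a reduced hypothetical "users" table, SQLite's "INSERT OR REPLACE", and JSON instead of "str()" so that the tags can be read back as a list. The helper name "save_user" and the sample user are assumptions.

```python
import json
import sqlite3

conn = sqlite3.connect(':memory:')  # in-memory database, nothing is written to disk
conn.execute('''CREATE TABLE IF NOT EXISTS users (
    pseudo text PRIMARY KEY,
    fav_color text,
    fav_tags text
)''')


def save_user(conn, pseudo, fav_color, tags):
    """Insert the user, or replace the existing row that has the same pseudo."""
    with conn:  # commits on success, rolls back on error
        conn.execute('INSERT OR REPLACE INTO users VALUES (?, ?, ?)',
                     (pseudo, fav_color, json.dumps(list(tags))))


save_user(conn, 'alice', '#FF0000', ('sea', 'beach'))
save_user(conn, 'alice', '#00FF00', ('forest',))  # same pseudo: the row is replaced

rows = conn.execute('SELECT pseudo, fav_color, fav_tags FROM users').fetchall()
print(rows)
print(json.loads(rows[0][2]))
```

Note that "INSERT OR REPLACE" deletes the conflicting row before inserting the new one, so every column has to be supplied on each save.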

# Notebook 4

In [None]:
import os
import ast
import sqlite3
import itertools
import ipywidgets as widgets
from dotenv import load_dotenv
load_dotenv()

In [None]:
# Set the base folder path for the project
output_path = "../output"
images_path = os.path.join(output_path, "images")
metadata_path = os.path.join(output_path, "metadata")
config_path = os.path.join(output_path, "config")

list_of_paths = [output_path, images_path, metadata_path, config_path]

# Set SQL variables
sql_host = os.getenv("SQL_HOST")
sql_user = os.getenv("SQL_USER")
sql_password = os.getenv("SQL_PASSWORD")
sql_database = os.getenv("SQL_DATABASE")

In [None]:
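# The metadata must be loaded into dict_metadata before the form is built;
# with an empty dictionary the list of tags is simply empty
dict_metadata = {}

# Parse the stringified tag lists, flatten them and remove duplicates
tags = sorted(set(itertools.chain.from_iterable(ast.literal_eval(t) for t in dict_metadata.get('tags', []))))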
user_data = {}

# Create a label for the title
title_label = widgets.Label(value='User Information Form')

# Create a text box for the user's pseudo
pseudo = widgets.Text(description='Pseudo :')

# Create a color picker for favorite colors
color_picker = widgets.ColorPicker(
    concise=True,
    description='Favorite Colors:',
    value='#FF0000',
    continuous_update=False,
    disabled=False
)

# Create a multiple-selection list of tags
tag_dropdown = widgets.SelectMultiple(
    options=tags,
    value=[],
    description='Tags:',
    disabled=False
)

# Create a dropdown for image orientation
orientation_dropdown = widgets.Dropdown(
    options=['Portrait', 'Landscape'],
    value='Portrait',
    description='Orientation:'
)

# Create sliders for image height and width
height_slider = widgets.IntSlider(min=100, max=4000, step=100, description='Height:')
width_slider = widgets.IntSlider(min=100, max=4000, step=100, description='Width:')

#  Create a button to submit the form
submit_button = widgets.Button(description='Submit')

# Create a VBox container for the widgets
form_container = widgets.VBox([
    title_label,
    pseudo,
    color_picker,
    orientation_dropdown,
    height_slider,
    width_slider,
    tag_dropdown,
    submit_button
])

form_container.layout = widgets.Layout(
    width='600px',
    height='500px',
    justify_content='center',  # Center the widgets vertically (main axis of a VBox)
    align_items='center'  # Center the widgets horizontally (cross axis of a VBox)
)

# Define a function to handle form submission
def on_submit_button_clicked(b):
    # Store the submitted preferences under the user's pseudonym
    user_data[pseudo.value] = {
        'fav_color': color_picker.value,
        'fav_orientation': orientation_dropdown.value,
        'fav_height': height_slider.value,
        'fav_width': width_slider.value,
        'tags': list(tag_dropdown.value)  # SelectMultiple returns a tuple
    }

# Attach the on_submit_button_clicked function to the button click event
submit_button.on_click(on_submit_button_clicked)

# Display the form container
display(form_container)

In [None]:
print(user_data)

In [None]:
list_columns = ['fav_color', 'fav_orientation', 'fav_height', 'fav_width']

def save_metadata(user_data):
    """Insert or update each user's preferences in the SQLite users table."""
    # Validate before touching the database: at least one user and no missing value
    if not user_data or any(user_data[pseudo].get(column) is None
                            for pseudo in user_data for column in list_columns):
        print("Invalid User data")
        return

    conn = None
    try:
        # Open a connection to the database
        conn = sqlite3.connect(os.path.join(metadata_path, 'metadata.db'))
        # Create a cursor
        c = conn.cursor()
        # Create the users table if it doesn't exist
        c.execute(
            '''CREATE TABLE IF NOT EXISTS users (
            pseudo text PRIMARY KEY,
            fav_color text,
            fav_orientation text,
            fav_height integer,
            fav_width integer
        )''')

        for pseudo, prefs in user_data.items():
            values = [prefs[column] for column in list_columns]
            # Update the row if the user already exists, insert it otherwise
            c.execute("SELECT 1 FROM users WHERE pseudo=?", (pseudo,))
            if c.fetchone():
                c.execute("UPDATE users SET fav_color=?, fav_orientation=?, fav_height=?, fav_width=? WHERE pseudo=?",
                          (*values, pseudo))
            else:
                c.execute("INSERT INTO users VALUES (?, ?, ?, ?, ?)", (pseudo, *values))
            conn.commit()
            print(f"User {pseudo} saved to database successfully")

    except Exception as e:
        print("Error saving user to database : ", e)
    finally:
        # Close the connection even if an error occurred
        if conn is not None:
            conn.close()

save_metadata(user_data)
print(user_data)
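The "update if the user exists, insert otherwise" logic of `save_metadata` can also be written as a single SQLite upsert statement. The sketch below is only an illustration under assumptions: it uses an in-memory database, a hypothetical helper named `upsert_user`, and the `ON CONFLICT ... DO UPDATE` clause, which requires SQLite 3.24 or later.

```python
import sqlite3


def upsert_user(conn, pseudo, prefs):
    """Insert a user, or update the row if the pseudo already exists (hypothetical helper)."""
    conn.execute(
        """INSERT INTO users (pseudo, fav_color, fav_orientation, fav_height, fav_width)
           VALUES (?, ?, ?, ?, ?)
           ON CONFLICT(pseudo) DO UPDATE SET
               fav_color = excluded.fav_color,
               fav_orientation = excluded.fav_orientation,
               fav_height = excluded.fav_height,
               fav_width = excluded.fav_width""",
        (pseudo, prefs['fav_color'], prefs['fav_orientation'],
         prefs['fav_height'], prefs['fav_width']))
    conn.commit()


# In-memory database with the same schema as the users table above
conn = sqlite3.connect(':memory:')
conn.execute("""CREATE TABLE IF NOT EXISTS users (
    pseudo text PRIMARY KEY,
    fav_color text,
    fav_orientation text,
    fav_height integer,
    fav_width integer
)""")

# The second call updates the existing row instead of raising an IntegrityError
upsert_user(conn, 'alice', {'fav_color': '#FF0000', 'fav_orientation': 'Portrait',
                            'fav_height': 600, 'fav_width': 400})
upsert_user(conn, 'alice', {'fav_color': '#00FF00', 'fav_orientation': 'Landscape',
                            'fav_height': 800, 'fav_width': 1200})
rows = conn.execute("SELECT * FROM users").fetchall()
print(rows)
conn.close()
```

With this approach the primary key on `pseudo` does the existence check, so no separate `SELECT` is needed.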

# Notebook 5

In [None]:
import os

from mysql.connector import pooling
from tqdm import tqdm
from sklearn.metrics.pairwise import euclidean_distances
import spacy
import pandas as pd
from dotenv import load_dotenv
load_dotenv()

In [None]:
preferences = {
    'Make': '',
    'ImageWidth': '',
    'ImageHeight': '',
    'Orientation': 1,
    'dominant_color': '#73AD3D',
    'tags': ['vase', 'toilet']
}

# Set SQL variables
sql_host = os.getenv("SQL_HOST")
sql_user = os.getenv("SQL_USER")
sql_password = os.getenv("SQL_PASSWORD")
sql_database = os.getenv("SQL_DATABASE")

# set the database config
config = {
    'user': sql_user,
    'password': sql_password,
    'host': sql_host,
    'port': '3306',
    'database': sql_database,
}

In [None]:
# Create a connection pool
connection_pool = pooling.MySQLConnectionPool(pool_name="mypool",
                                              pool_size=2,
                                              **config)

In [None]:
import ast


def get_metadata_from_mariadb_db():
    """
    Get the metadata from the MariaDB database

    :return: A pandas DataFrame with the metadata
    """
    # Open a connection to the database
    conn = connection_pool.get_connection()
    try:
        # Create a cursor
        c = conn.cursor()

        # Retrieve the metadata: one row per file, with "key<TAB>value" pairs separated by newlines
        c.execute("""
            SELECT filename, GROUP_CONCAT(CONCAT(mkey, '\t', mvalue) SEPARATOR '\n') AS metadata
            FROM metadata
            GROUP BY filename;
        """)
        metadata = c.fetchall()
    finally:
        # Return the connection to the pool, even if the query fails
        conn.close()

    # Columns of the resulting DataFrame
    columns = ['filename', 'Make', 'Software', 'ImageWidth', 'ImageHeight', 'Orientation', 'DateTimeOriginal',
               'dominant_color', 'tags']

    # Build one dict per image, then create the DataFrame in a single step
    # (DataFrame.append was removed in pandas 2.0)
    rows = []
    for image in tqdm(metadata, desc="Get metadata from database"):
        try:
            props = {'filename': image[0]}
            for prop in image[1].split('\n'):
                if prop:
                    # Split on the first tab only, in case the value itself contains a tab
                    k, value = prop.split('\t', 1)
                    if k in columns[1:]:
                        if k == 'dominant_color':
                            # Parse the stored list safely and keep the first item of each entry
                            props[k] = [color[0] for color in ast.literal_eval(value)]
                        elif k == 'tags':
                            props[k] = ast.literal_eval(value)
                        else:
                            props[k] = value
            rows.append(props)
        except Exception as e:
            print(e, image)
    df = pd.DataFrame(rows, columns=columns)

    return df
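To make the parsing step concrete, here is a small sketch that decodes one `GROUP_CONCAT` result outside the database. The helper name `parse_metadata` and the sample string are hypothetical. In particular, the second element of each `dominant_color` entry (shown here as a proportion) is an assumption, since the notebook only uses the first element of each entry.

```python
import ast


def parse_metadata(metadata_str):
    """Turn 'key<TAB>value' lines into a dict, decoding list-valued fields (hypothetical helper)."""
    props = {}
    for line in metadata_str.split('\n'):
        if not line:
            continue
        # Split on the first tab only, so tabs inside the value are preserved
        key, value = line.split('\t', 1)
        if key in ('dominant_color', 'tags'):
            # literal_eval only accepts Python literals, unlike eval
            value = ast.literal_eval(value)
        props[key] = value
    return props


# Made-up sample in the same "key<TAB>value" layout as the query result
sample = "Make\tCanon\ntags\t['vase', 'toilet']\ndominant_color\t[('#73AD3D', 0.6), ('#FFFFFF', 0.4)]"
parsed = parse_metadata(sample)
colors = [color[0] for color in parsed['dominant_color']]
print(parsed['tags'], colors)
```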

In [None]:
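def hex_to_rgb(color):
    """Convert a hex color such as '#73AD3D' to an (r, g, b) tuple; return (0, 0, 0) if it cannot be parsed."""
    try:
        # remove the leading # from the color
        color = color.lstrip('#')
        # convert the color to rgb values
        return tuple(int(color[i:i + 2], 16) for i in (0, 2, 4))
    except (AttributeError, TypeError, ValueError):
        return 0, 0, 0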


def get_clean_preferences(df_preferences):
    # remove the rows with nan in dominant_color (copy to avoid modifying a slice)
    df_preferences = df_preferences.dropna(subset=['dominant_color']).copy()
    # replace the remaining NaN values with 0
    df_preferences = df_preferences.fillna(0)
    # convert the hex color to an (r, g, b) tuple
    df_preferences['dominant_color'] = df_preferences['dominant_color'].apply(hex_to_rgb)

    return df_preferences


def get_clean_dataset():
    df_metadata = get_metadata_from_mariadb_db()
    # remove the rows with nan in dominant_color
    df_metadata = df_metadata.dropna(subset=['dominant_color']).copy()
    # split dominant_color (a list of up to 4 hex colors) into 4 columns of (r, g, b) tuples,
    # using (0, 0, 0) when an image has fewer colors
    for i in range(4):
        df_metadata[f'color{i + 1}'] = df_metadata['dominant_color'].apply(
            lambda colors, i=i: hex_to_rgb(colors[i]) if len(colors) > i and colors[i] else (0, 0, 0))
    df_metadata = df_metadata.drop('dominant_color', axis=1)

    # replace the remaining NaN values (including missing tags) with 0
    df_metadata = df_metadata.fillna(0)
    # keep only the columns used by the recommendation functions
    df_metadata = df_metadata[
        ['filename', 'Make', 'ImageWidth', 'ImageHeight', 'Orientation', 'DateTimeOriginal', 'tags', 'color1', 'color2',
         'color3', 'color4']].copy()
    # replace missing Make values with empty strings
    df_metadata['Make'] = df_metadata['Make'].replace(0, '')

    # reset the index so that scores computed later align row by row
    return df_metadata.reset_index(drop=True)

In [None]:
df_pref = pd.DataFrame([preferences])
df_preferences = get_clean_preferences(df_pref)
df_preferences.head()

In [None]:
df_metadata = get_clean_dataset()

In [None]:
df_metadata.head()
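The color columns always hold four RGB tuples per image, with black used as a filler when fewer dominant colors are available. A minimal standalone sketch of that padding, using the hypothetical helpers `to_rgb` and `pad_colors` rather than the notebook's own functions:

```python
def to_rgb(color):
    """Convert '#RRGGBB' to an (r, g, b) tuple (hypothetical helper)."""
    color = color.lstrip('#')
    return tuple(int(color[i:i + 2], 16) for i in (0, 2, 4))


def pad_colors(hex_colors, size=4):
    """Convert up to `size` hex colors to RGB and pad the result with black."""
    rgb = [to_rgb(c) for c in hex_colors[:size]]
    return rgb + [(0, 0, 0)] * (size - len(rgb))


padded = pad_colors(['#73AD3D', '#FFFFFF'])
print(padded)  # [(115, 173, 61), (255, 255, 255), (0, 0, 0), (0, 0, 0)]
```

Note that the black filler is indistinguishable from a real black dominant color, which may bias distance computations for images with fewer than four colors.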

# Color Similarity

In [None]:
def recommend_colors(df_metadata, df_preferences, n=0):
    # Work on a copy of the dataset
    data = df_metadata.copy()

    # Extract the individual r, g, and b values from the tuples in the color columns
    rgb_columns = []
    for i in range(1, 5):
        channels = [f'r{i}', f'g{i}', f'b{i}']
        data[channels] = pd.DataFrame(data[f'color{i}'].tolist(), index=data.index)
        rgb_columns += channels

    # Normalize the r, g, and b columns to be between 0 and 1
    data[rgb_columns] = data[rgb_columns] / 255

    # Normalize the input RGB color to be between 0 and 1
    r, g, b = df_preferences['dominant_color'].iloc[0]
    r_norm, g_norm, b_norm = r / 255, g / 255, b / 255

    # Compute the Euclidean distance between the input color (repeated 4 times)
    # and the 4 colors of each image
    distances = euclidean_distances([[r_norm, g_norm, b_norm] * 4], data[rgb_columns])[0]
    # Convert the distance to a similarity in [0, 1] so that higher means closer, like the
    # other similarity columns; sqrt(12) is the largest distance in the 12-dimensional unit cube
    data['similarity_dominant_color'] = 1 - distances / 12 ** 0.5

    # Sort by similarity in descending order and keep the n closest matches (all rows when n == 0)
    closest_matches = data.sort_values('similarity_dominant_color', ascending=False)[
        ['filename', 'color1', 'color2', 'color3', 'color4', 'similarity_dominant_color']]
    if n > 0:
        closest_matches = closest_matches.head(n)

    return closest_matches


In [None]:
recommend_colors(df_metadata, df_preferences)  # OK
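As a worked example of the distance computation, the sketch below compares two single colors in normalized RGB space, using only the standard library. It is a simplification: the notebook compares the preferred color against four colors per image at once. Dividing by `sqrt(3)`, the largest possible distance in the unit RGB cube, is one possible way to turn the distance into a score between 0 and 1, and the function names are hypothetical.

```python
import math


def hex_to_unit_rgb(color):
    """Convert '#RRGGBB' to (r, g, b) with each channel scaled to [0, 1]."""
    color = color.lstrip('#')
    return tuple(int(color[i:i + 2], 16) / 255 for i in (0, 2, 4))


def color_similarity(hex_a, hex_b):
    """Return 1 for identical colors and about 0 for opposite corners of the RGB cube."""
    distance = math.dist(hex_to_unit_rgb(hex_a), hex_to_unit_rgb(hex_b))
    return 1 - distance / math.sqrt(3)


same = color_similarity('#73AD3D', '#73AD3D')      # identical colors
opposite = color_similarity('#000000', '#FFFFFF')  # black versus white
close = color_similarity('#73AD3D', '#70A840')     # two similar greens
print(same, opposite, close)
```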

# Tag Similarity

In [None]:
def recommend_tags(df_metadata, df_preferences, n=0, nlp=None):
    # Load the spaCy model if it hasn't been loaded
    if nlp is None:
        nlp = spacy.load("en_core_web_md")

    # Preferred tags of the user
    preferences = df_preferences['tags'].iloc[0]
    # Work on a copy of the dataset without rows that have no tags value
    df = df_metadata.copy()
    df = df.dropna(subset=["tags"]).reset_index(drop=True)
    # replace missing tags (stored as 0) with an empty list
    df['tags'] = df['tags'].apply(lambda x: x if x else [])

    # Precompute the similarity between each tag word and each preference word
    pref_docs = {pref_word: nlp(pref_word) for pref_word in set(preferences)}
    similarity_dict = {}
    for tag_word in set(word for tags in df['tags'] for word in tags):
        tag_doc = nlp(tag_word)
        for pref_word, pref_doc in pref_docs.items():
            similarity_dict[(tag_word, pref_word)] = tag_doc.similarity(pref_doc)

    # Compute the average similarity for each row in the dataframe
    similarities = []
    for tags in df['tags']:
        sum_similarity = sum(similarity_dict[(tag_word, pref_word)]
                             for tag_word in tags for pref_word in preferences)
        # Avoid a division by zero when the image or the user has no tags
        pair_count = len(tags) * len(preferences)
        similarities.append(sum_similarity / pair_count if pair_count else 0)

    # Add the similarity scores to a new column in the dataframe
    df['similarity_tags'] = similarities
    if n == 0:
        closest_matches = df.sort_values('similarity_tags', ascending=False)[
            ['filename', 'similarity_tags']]
    else:
        closest_matches = df.sort_values('similarity_tags', ascending=False).head(n)[
            ['filename', 'similarity_tags']]

    return closest_matches


In [None]:
recommend_tags(df_metadata, df_preferences)  # OK
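The tag score is the average similarity over every (image tag, preferred tag) pair. The sketch below reproduces that averaging with made-up two-dimensional vectors and a plain cosine similarity, so it runs without spaCy. The vectors and function names are assumptions for illustration only, not values from the `en_core_web_md` model.

```python
import math

# Toy word vectors, invented for this example
vectors = {'vase': (1.0, 0.0), 'flower': (0.8, 0.6), 'car': (0.0, 1.0)}


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))


def average_similarity(tags, preferences):
    """Average the similarity over all (tag, preference) pairs; 0 when there is no pair."""
    pairs = [(t, p) for t in tags for p in preferences]
    if not pairs:
        return 0.0
    return sum(cosine(vectors[t], vectors[p]) for t, p in pairs) / len(pairs)


# 'flower' is close to 'vase' (0.8) and 'car' is orthogonal to it (0.0), so the average is about 0.4
score = average_similarity(['flower', 'car'], ['vase'])
empty = average_similarity([], ['vase'])
print(score, empty)
```

Because the score is an average, an image with one highly relevant tag and many unrelated tags ranks lower than an image with only the relevant tag.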

# Make Similarity

In [None]:
def recommend_make(df_metadata, df_preferences, n=0):
    # Load the spaCy model
    nlp = spacy.load("en_core_web_md")

    # Preferred make of the user
    make = df_preferences['Make'].iloc[0]
    # Work on a copy of the dataset without rows that have no Make value
    df = df_metadata.copy()
    df = df.dropna(subset=["Make"]).reset_index(drop=True)

    # Convert the preferred make to a spaCy document
    make_doc = nlp(str(make))

    # Compute the similarity once per distinct Make value instead of once per row
    make_similarity = {m: make_doc.similarity(nlp(str(m))) for m in df['Make'].unique()}

    # Add the similarity scores to a new column in the dataframe
    df['similarity_make'] = df['Make'].map(make_similarity)
    if n == 0:
        closest_matches = df.sort_values('similarity_make', ascending=False)[
            ['filename', 'similarity_make']]
    else:
        closest_matches = df.sort_values('similarity_make', ascending=False).head(n)[
            ['filename', 'similarity_make']]

    return closest_matches


In [None]:
recommend_make(df_metadata, df_preferences)  # OK

# Orientation Similarity

In [None]:
def recommend_orientation(df_metadata, df_preferences, n=0):
    # Preferred orientation of the user (0 or 1)
    orientation = int(df_preferences['Orientation'].iloc[0])
    # Work on the rows that have an Orientation value
    df = df_metadata.dropna(subset=["Orientation"]).reset_index(drop=True)
    # Map missing values ('', '0' or 0) to 0 and every other value to 1
    df['Orientation'] = df['Orientation'].apply(lambda x: 0 if x in ('', '0', 0) else 1)

    # Orientation is 0 or 1, so the similarity is 1 when it matches the preference and 0 otherwise
    df['similarity_orientation'] = 1 - (df['Orientation'] - orientation).abs()

    # Sort by similarity, best matches first, and keep the n best (all rows when n == 0)
    closest_matches = df.sort_values('similarity_orientation', ascending=False)[
        ['filename', 'similarity_orientation']]
    if n > 0:
        closest_matches = closest_matches.head(n)

    return closest_matches


In [None]:
recommend_orientation(df_metadata, df_preferences)  # OK

# Size Similarity

In [None]:
def recommend_size(df_metadata, df_preferences, n=0):
    # Preferred width and height; empty or non-numeric values become NaN
    width = pd.to_numeric(df_preferences['ImageWidth'].iloc[0], errors='coerce')
    height = pd.to_numeric(df_preferences['ImageHeight'].iloc[0], errors='coerce')
    # Work on the rows that have both dimensions
    df = df_metadata.dropna(subset=["ImageWidth", "ImageHeight"]).reset_index(drop=True)

    # Convert the ImageWidth and ImageHeight column to integer type
    df[['ImageWidth', 'ImageHeight']] = df[['ImageWidth', 'ImageHeight']].astype(int)

    # Preferred area in pixels
    product = width * height

    if pd.isna(product) or product == 0:
        # No usable size preference: every image gets a score of 0
        df['similarity_size'] = 0.0
    else:
        # 1 minus the relative difference between the image area and the preferred area
        df['similarity_size'] = 1 - (product - df['ImageWidth'] * df['ImageHeight']).abs() / product

    if n == 0:
        closest_matches = df.sort_values('similarity_size', ascending=False)[
            ['filename', 'similarity_size']]
    else:
        closest_matches = df.sort_values('similarity_size', ascending=False).head(n)[
            ['filename', 'similarity_size']]

    return closest_matches


In [None]:
recommend_size(df_metadata, df_preferences)  # OK
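The size score compares pixel areas: `1 - |preferred_area - image_area| / preferred_area`. A short worked example of that formula in plain Python, with the hypothetical helper `size_similarity` and made-up dimensions:

```python
def size_similarity(pref_width, pref_height, width, height):
    """Return 1 for an identical area, decreasing as the relative area difference grows."""
    preferred_area = pref_width * pref_height
    return 1 - abs(preferred_area - width * height) / preferred_area


same = size_similarity(4000, 3000, 4000, 3000)    # same area
half = size_similarity(4000, 3000, 2000, 3000)    # half the preferred area
larger = size_similarity(4000, 3000, 8000, 6000)  # four times the preferred area
print(same, half, larger)  # 1.0 0.5 -2.0
```

The last case shows that the score is not bounded below: images much larger than the preference get a negative value, which may dominate a weighted sum unless the score is clipped.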

In [None]:
def recommend(df_metadata, df_preferences, n=0):
    # Assign weights to properties based on user preferences
    weights = {
        'Make': 5.0,
        'ImageWidth': 1.0,
        'ImageHeight': 1.0,
        'Orientation': 2.0,
        'dominant_color': 3.0,
        'tags': 5.0
    }

    # Map each preference to its recommendation method and the score column that method returns
    preference_methods = {
        'Make': (recommend_make, 'similarity_make'),
        'ImageWidth': (recommend_size, 'similarity_size'),
        'ImageHeight': (recommend_size, 'similarity_size'),
        'Orientation': (recommend_orientation, 'similarity_orientation'),
        'dominant_color': (recommend_colors, 'similarity_dominant_color'),
        'tags': (recommend_tags, 'similarity_tags')
    }

    # Remove preferences with no values
    preferences = {k: v for k, v in df_preferences.iloc[0].to_dict().items()
                   if k in preference_methods and v != ''}

    # Normalize the weights once so that they sum to 1
    weights_sum = sum(weights.values())
    weights = {key: weight / weights_sum for key, weight in weights.items()}

    # Work on a copy with a 0..N-1 index so the scores align with the rows returned by each method
    df = df_metadata.reset_index(drop=True)
    df['similarity_score'] = 0.0
    for preference in preferences:
        method, column = preference_methods[preference]
        # Score every row (n=0) so that rows outside a per-property top n still get a value
        similarity = method(df, df_preferences, 0)[column].astype(float)
        # Rows without a score for this property contribute 0
        df['similarity_score'] = df['similarity_score'].add(similarity * weights[preference], fill_value=0)

    # Sort by similarity score and keep the n best matches (all rows when n == 0)
    closest_matches = df.sort_values('similarity_score', ascending=False)[['filename', 'similarity_score']]
    if n > 0:
        closest_matches = closest_matches.head(n)

    return closest_matches


In [None]:
recommend(df_metadata, df_preferences)  # OK
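The final score is a weighted sum of the per-property scores, with the weights normalized so that they sum to 1. The arithmetic can be checked by hand with the sketch below. The weights are the ones used in `recommend`, while the per-property scores are invented values for a single hypothetical image.

```python
# Weights from recommend(); they sum to 17 before normalization
weights = {'Make': 5.0, 'ImageWidth': 1.0, 'ImageHeight': 1.0,
           'Orientation': 2.0, 'dominant_color': 3.0, 'tags': 5.0}
total = sum(weights.values())
normalized = {key: weight / total for key, weight in weights.items()}

# Hypothetical scores for one image; properties without a preference are simply absent
scores = {'Orientation': 1.0, 'dominant_color': 0.5, 'tags': 0.8}

# (2 * 1.0 + 3 * 0.5 + 5 * 0.8) / 17 = 7.5 / 17
combined = sum(normalized[key] * score for key, score in scores.items())
print(round(combined, 4))  # 0.4412
```

Since the weights of unused preferences still count toward the total, the maximum reachable score here is 10 / 17 rather than 1. Normalizing over the active preferences only would be an alternative design.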