# LanceDB Myntra Fashion Search Engine

## Download Data
- For this project, download the [Myntra Fashion Product Dataset](https://www.kaggle.com/datasets/hiteshsuthar101/myntra-fashion-product-dataset) from Kaggle and store it in the `input` folder.
- When creating the table, pass the path of the folder that contains the `.jpg` images, for example `/content/input/Images/Images`.

## Preliminaries

In [None]:
%%capture

!pip install lancedb
!pip install open_clip_torch

In [None]:
import os
import pandas as pd
from PIL import Image
from pathlib import Path
from random import sample

import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry

from typing import Any

## Embedding Model

In [None]:
def register_model(model_name: str) -> Any:
    """
    Look up an embedding function by name in LanceDB's registry and instantiate it.

    Args:
        model_name (str): Registry key of the embedding function (for example "open-clip").

    Returns:
        model: A new instance created from the registry entry for ``model_name``.

    Usage:
    >>> model = register_model("open-clip")
    """
    factory = EmbeddingFunctionRegistry.get_instance().get(model_name)
    return factory.create()

## Schema

In [None]:
# Create the OpenCLIP ("open-clip") embedding function from LanceDB's registry.
# The Myntra schema below uses it to size its vector column and to embed images.
clip = register_model("open-clip")


class Myntra(LanceModel):
    """
    LanceDB table schema for the Myntra image collection.

    Attributes:
        vector (Vector): Embedding of the item's image. It is sized to ``clip.ndims()``
            and declared as the embedding function's vector field.
        image_uri (str): Path of the item's image file. It is declared as the
            embedding function's source field.
    """

    vector: Vector(clip.ndims()) = clip.VectorField()
    image_uri: str = clip.SourceField()

    @property
    def image(self):
        """Open the file at ``image_uri`` with Pillow and return the image."""
        uri = self.image_uri
        return Image.open(uri)


# Map a schema name to its schema class
def get_schema_by_name(schema_name: str) -> Any:
    """
    Return the schema class registered under ``schema_name``.

    Args:
        schema_name (str): Name of the schema to look up (for example "Myntra").

    Returns:
        object: The matching schema class, or None when the name is unknown.

    Usage:
    >>> schema = get_schema_by_name("Myntra")
    """
    known_schemas = {"Myntra": Myntra}
    return known_schemas.get(schema_name, None)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


open_clip_model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

## Creating a Table

In [None]:
def create_table(
    database: str,
    table_name: str,
    data_path: str,
    mode: str = "create",  # "create", "overwrite"
    schema: Any = Myntra,
    sample_size: int = 100,
) -> None:
    """
    Create a table in the specified vector database and add data to it.

    If the table already exists and ``mode`` is not "overwrite", the existing table
    is left untouched. Otherwise any existing table is dropped and a new one is
    created with ``schema``. Then up to ``sample_size`` randomly sampled ``*.jpg``
    files from ``data_path`` are added as ``image_uri`` rows.

    Args:
        database (str): The name/path of the database to connect to.
        table_name (str): The name of the table to create.
        data_path (str): The directory containing the ``.jpg`` images (not searched recursively).
        mode (str): "create" keeps an existing table; "overwrite" replaces it. Defaults to "create".
        schema (Schema, optional): The schema to use for the table. Defaults to Myntra.
        sample_size (int, optional): The maximum number of images to sample from the data.
            If fewer images are found, all of them are added. Defaults to 100.

    Returns:
        None

    Usage:
    >>> create_table(database="lancedb_myntra", table_name="fashion", data_path="input")
    """

    # Connect to the lancedb database
    db = lancedb.connect(database)

    # Keep an existing table unless we were explicitly asked to overwrite it
    if table_name in db and mode != "overwrite":
        print(f"Table {table_name} already exists in the database")
        return

    print(f"Creating table {table_name} in the database")

    if table_name in db:
        db.drop_table(table_name)

    # Create the table with the given schema
    table = db.create_table(table_name, schema=schema, mode=mode)

    # Collect the image URIs from the data directory
    p = Path(data_path).expanduser()
    uris = [str(f) for f in p.glob("*.jpg")]
    print(f"Found {len(uris)} images in {p}")

    # random.sample raises ValueError when asked for more items than exist,
    # so cap the sample at the number of images actually found.
    # Increase sample_size for more accurate results, at the cost of
    # more time spent computing embeddings.
    n_samples = min(sample_size, len(uris))
    if n_samples <= 0:
        print("No images to add; the table was created empty")
        return
    uris = sample(uris, n_samples)

    # Add the data to the table. Only image_uri is supplied; the vector column
    # is expected to be filled by the schema's embedding function.
    print(f"Adding {len(uris)} images to the table")
    table.add(pd.DataFrame({"image_uri": uris}))
    print(f"Added {len(uris)} images to the table")

In [None]:
# data_path must point at the folder that directly contains the .jpg images.
# mode="overwrite" drops and rebuilds the table on every run.
create_table(
    mode="overwrite",
    data_path="/content/input/Images/Images",
    table_name="fashion",
    database="lancedb_myntra",
)

Creating table fashion in the database
Found 6 images in /content/input/Images/Images
Adding 5 images to the table


100%|██████████| 5/5 [00:01<00:00,  3.77it/s]

Added 5 images to the table





## Vector Search

In [None]:
def run_vector_search(
    database: str,
    table_name: str,
    schema: Any,
    search_query: Any,
    limit: int = 6,
    output_folder: str = "output",
) -> None:
    """
    Perform a vector search on the given table and save the matching images.

    The search can be text-to-image or image-to-image:
    - A string ending in '.jpg', '.jpeg' or '.png' (case-insensitive) is opened
      with Pillow and used as an image query.
    - Any other value, such as plain text or an already-loaded PIL image, is
      passed to the search unchanged.

    Up to ``limit`` results are retrieved and saved as ``image_{i}.jpg`` in
    ``output_folder``. Any regular files already in that folder are deleted first.

    Args:
        database (str): The path to the database.
        table_name (str): The name of the table.
        schema (Schema): The schema to use for converting search results to Pydantic models.
        search_query (Any): The search query, can be text, an image path, or a PIL image.
        limit (int, optional): The maximum number of results to return. Defaults to 6.
        output_folder (str, optional): The folder to save the output images. Defaults to "output".

    Returns:
        None

    Usage:
    >>> run_vector_search(database="lancedb_myntra", table_name="fashion", schema=Myntra, search_query="Black Kurta")

    """

    # Start from an empty output folder: create it, or clear files from a previous run
    if os.path.exists(output_folder):
        for file in os.listdir(output_folder):
            file_path = os.path.join(output_folder, file)
            # os.remove fails on directories, so only delete regular files
            if os.path.isfile(file_path):
                os.remove(file_path)
    else:
        os.makedirs(output_folder)

    # Connect to the lancedb database and open the table
    db = lancedb.connect(database)
    table = db.open_table(table_name)

    # Decide whether the query is an image path, an image object, or text
    query_image = None
    if isinstance(search_query, str):
        if search_query.lower().endswith((".jpg", ".jpeg", ".png")):
            query_image = Image.open(search_query)
            search_query = query_image
    elif isinstance(search_query, Image.Image):
        print(
            "Running via Streamlit, search query is already an array so skipping opening image using Pillow"
        )

    # Perform the vector search. Close the query image we opened ourselves
    # once the results are materialized.
    try:
        rs = table.search(search_query).limit(limit).to_pydantic(schema)
    finally:
        if query_image is not None:
            query_image.close()

    # The table may hold fewer rows than `limit`, so iterate the actual results
    for i, result in enumerate(rs):
        image_path = os.path.join(output_folder, f"image_{i}.jpg")
        with result.image as img:
            # JPEG cannot store alpha or palette modes (e.g. RGBA/P PNGs); convert first
            out = img if img.mode in ("RGB", "L") else img.convert("RGB")
            out.save(image_path, "JPEG")

After the search completes, the matching images are saved as `image_0.jpg`, `image_1.jpg`, … in the `output` folder. Any files already in that folder are deleted first.

## Text Search

Run a text-to-image search against the table. Results are stored in the `output` folder (relative to the notebook's working directory).

In [None]:
# Text-to-image search. Use the same relative database path as create_table so
# this works outside Colab too (in Colab it resolves to /content/lancedb_myntra).
run_vector_search(
    database="lancedb_myntra",
    table_name="fashion",
    schema=Myntra,
    search_query="White Kurta",
    limit=3,
    output_folder="output",
)

## Image Search

Run an image-to-image search against the table. Results are stored in the `output` folder (relative to the notebook's working directory).

In [None]:
# Image-to-image search: a query string ending in .jpg/.png is opened as an image
run_vector_search(
    search_query="/content/input/Images/Images/0.jpg",
    limit=3,
    schema=Myntra,
    table_name="fashion",
    database="lancedb_myntra",
    output_folder="output",
)

In [None]:
# !rm -rf lancedb_myntra # To delete the DB