# Example of Image Search

This is an example of image search using [OpenAI CLIP](https://huggingface.co/docs/transformers/model_doc/clip) and TiDB Serverless Vector Search.

We will use the CLIP model to encode each image into a 512-dimensional vector and store the vectors in TiDB Serverless. Then we use the same model to encode a text query and search TiDB Serverless for the most similar images.

## Install dependencies


In [None]:
# Install the example's Python dependencies quietly (-q) into the notebook kernel.
%pip install -q torch transformers requests ipyplot datasets sqlalchemy pymysql tidb_vector

## Prepare the environment

> **Note:**
>
> - The environment variables are already set for you in TiDB Lab. If you want to run this example in your local environment, refer to the [Prerequisites](https://github.com/pingcap/tidb-vector-python/blob/main/examples/README.md#prerequisites) section to set up the environment.
> - In this example, we use CLIP to generate text and image embeddings with 512 dimensions.


In [None]:
import os

# TiDB Serverless connection settings taken from the environment.
# Each value is None when the corresponding variable is not set.
TIDB_HOST, TIDB_USERNAME, TIDB_PASSWORD = (
    os.getenv(var_name) for var_name in ("TIDB_HOST", "TIDB_USERNAME", "TIDB_PASSWORD")
)

# Length of the CLIP text and image embedding vectors stored in the table.
CLIP_DIMENSION = 512

## Initialize the Database and Table

In [None]:
from sqlalchemy import URL, create_engine, Column, Integer
from sqlalchemy.orm import declarative_base, sessionmaker
from tidb_vector.sqlalchemy import VectorType

# Connection URL for the "test" database on port 4000, with pymysql's
# ssl_verify_cert / ssl_verify_identity options turned on.
# NOTE(review): URL(...) is constructed directly so the bool query values are
# passed through unchanged; URL.create() validates query values as strings.
tidb_url = URL(
    "mysql+pymysql",
    username=TIDB_USERNAME,
    password=TIDB_PASSWORD,
    host=TIDB_HOST,
    port=4000,
    database="test",
    query={"ssl_verify_cert": True, "ssl_verify_identity": True},
)
engine = create_engine(tidb_url)

Base = declarative_base()
Session = sessionmaker(bind=engine)


class ImageSearchTest(Base):
    """One row per stored image embedding."""

    __tablename__ = "image_search_test"

    id = Column(Integer, primary_key=True)
    # Position of the source image in `imagenet_images`.
    image_id = Column(Integer)
    # CLIP embedding with CLIP_DIMENSION components.
    embedding = Column(VectorType(CLIP_DIMENSION))


# Drop and recreate the table so every run starts from an empty table.
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)

## Initialize the CLIP model

In [None]:
import torch
from transformers import CLIPModel, CLIPProcessor

# Hugging Face checkpoint used for both the image and the text embeddings.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)


## Load test images

In [None]:
import datasets

# Load the "train" split of the theodor1289/imagenet-1k_tiny dataset.
imagenet_datasets = datasets.load_dataset(
    "theodor1289/imagenet-1k_tiny",
    split="train",
)

In [None]:
# Show the first record to see which fields each example carries.
first_record = imagenet_datasets[0]
first_record

Extract the images and preview up to 20 of them:

In [None]:
import ipyplot

# Pull the image out of every dataset record, keeping dataset order.
imagenet_images = [record['image'] for record in imagenet_datasets]

# Preview at most 20 of them as small thumbnails.
ipyplot.plot_images(imagenet_images, max_images=20, img_width=100)

## Define the encode function and other helper functions

In [None]:
def encode_images_to_embeddings(images, batch_size=32):
    """Encode a list of images into CLIP image embeddings.

    Args:
        images: A list of images accepted by the CLIP processor (for example
            PIL images).
        batch_size: How many images are sent through the model per forward
            pass. This bounds peak memory when encoding many images.

    Returns:
        A NumPy array with one embedding row per input image, in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    import numpy as np

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    batches = []
    # Inference only: skip autograd bookkeeping, so no detach() is needed.
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            inputs = processor(images=images[start:start + batch_size], return_tensors="pt")
            batches.append(model.get_image_features(**inputs).cpu().numpy())

    if not batches:
        # An empty input yields an empty (0, dim) array instead of a
        # processor/concatenate error.
        return np.empty((0, model.config.projection_dim), dtype=np.float32)
    return np.concatenate(batches)

def encode_text_to_embedding(text):
    """Return the CLIP text embedding of ``text`` as a 1-D NumPy array.

    Only the first row of the model output is returned, so pass a single
    query string.
    """
    with torch.no_grad():
        tokenized = processor(text=text, return_tensors="pt")
        features = model.get_text_features(**tokenized)
    # The tensor was produced under no_grad, so it converts to NumPy directly.
    return features[0].cpu().numpy()


## Store the images and their corresponding image embeddings in TiDB Serverless

In [None]:
# Encode every image once, then persist one row per image. `image_id` is the
# image's index in `imagenet_images`, which the search step later uses to map
# result rows back to images.
images_embedding = encode_images_to_embeddings(imagenet_images)
objects = [
    ImageSearchTest(image_id=image_id, embedding=embedding)
    for image_id, embedding in enumerate(images_embedding)
]

with Session() as session:
    session.add_all(objects)
    session.commit()

## Search for similar images using the text query

In [None]:
from sqlalchemy import asc

query_text = "dog"
query_text_embedding = encode_text_to_embedding(query_text)

# Cosine distance between each stored embedding and the query embedding,
# labelled so it can be ordered on by name.
distance = ImageSearchTest.embedding.cosine_distance(query_text_embedding).label("distance")

with Session() as session:
    top_matches = (
        session.query(ImageSearchTest, distance)
        .order_by(asc("distance"))
        .limit(5)
        .all()
    )

    # Map each row back to its image; report 1 - distance as a similarity score.
    similar_images = [imagenet_images[row.image_id] for row, _ in top_matches]
    similarities = [round(1 - dist, 3) for _, dist in top_matches]

# display the similar images
ipyplot.plot_images(similar_images, labels=similarities, img_width=100)
