# Build a Text-Image Search Engine in Celestica

This notebook illustrates how to build a text-image search engine from scratch using Celestica. Celestica is an open-source vector database built on top of IPFS/IPLD that supports nearest-neighbor embedding search across tens of millions of entries.

### Prepare the data

The dataset used in this demo is a subset of ImageNet (100 classes, 10 images per class) and is available on [GitHub](https://github.com/towhee-io/examples/releases/download/data/reverse_image_search.zip).

The dataset is organized as follows:
- **train**: directory of candidate images;
- **test**: directory of test images;
- **reverse_image_search.csv**: a CSV file containing an ***id***, ***path***, and ***label*** for each image.

Let's take a quick look:

In [None]:
! curl -L https://github.com/towhee-io/examples/releases/download/data/reverse_image_search.zip -O
! unzip -q -o reverse_image_search.zip

In [None]:
import pandas as pd

df = pd.read_csv('reverse_image_search.csv')
df.head()
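
Since the ingest step later converts each ***id*** to an integer and opens each ***path***, it can help to check the CSV rows up front. The following is a minimal sketch using only the standard library; `validate_rows` is a hypothetical helper, and the sample rows are made up for illustration rather than taken from the real file.

```python
import csv
import io

def validate_rows(rows):
    """Return a list of human-readable problems found in the CSV rows."""
    problems = []
    seen_ids = set()
    for line_no, row in enumerate(rows, start=2):  # line 1 is the header
        raw_id = (row.get('id') or '').strip()
        try:
            image_id = int(raw_id)
        except ValueError:
            problems.append(f"line {line_no}: id {raw_id!r} is not an integer")
        else:
            if image_id in seen_ids:
                problems.append(f"line {line_no}: duplicate id {image_id}")
            else:
                seen_ids.add(image_id)
        if not (row.get('path') or '').strip():
            problems.append(f"line {line_no}: empty path")
    return problems

# Hypothetical rows for illustration only; not taken from the real file.
sample = io.StringIO(
    "id,path,label\n"
    "0,./train/a/0.JPEG,a\n"
    "0,./train/a/1.JPEG,a\n"
    "x,,b\n"
)
problems = validate_rows(csv.DictReader(sample))
for problem in problems:
    print(problem)
```

To run the same check on the real file, pass `csv.DictReader(open('reverse_image_search.csv', encoding='utf-8-sig'))` instead of the in-memory sample.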

### Create a Celestica Collection

Let's first create a `text_image_search` collection.

In [None]:
# TBD

## Text Image Search

In this section, we'll show how to build our text-image search engine using Celestica. The basic idea behind text-image search is to extract embeddings from images and texts using a deep neural network and compare them with the embeddings stored in Celestica.

### Generate image and text embeddings with CLIP


We use [CLIP](https://openai.com/blog/clip/) to extract features from both images and text. CLIP jointly trains an image encoder and a text encoder to maximize the cosine similarity between the embeddings of matching image-text pairs, so images and text land in a shared embedding space.
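
The cells in this notebook normalize every embedding to unit length. The sketch below illustrates why that is convenient: for unit-length vectors, the plain dot product equals the cosine similarity. It uses toy 3-dimensional vectors as stand-ins for real CLIP embeddings, and the helper names are illustrative only.

```python
import math

def l2_normalize(vec):
    """Scale vec to unit length (the list equivalent of vec / np.linalg.norm(vec))."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cosine_similarity(a, b):
    """Cosine of the angle between a and b, computed from the raw vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

# Toy vectors standing in for an image embedding and a text embedding.
image_vec = [3.0, 4.0, 0.0]
text_vec = [4.0, 3.0, 0.0]

cos = cosine_similarity(image_vec, text_vec)
dot_of_units = dot(l2_normalize(image_vec), l2_normalize(text_vec))
print(cos, dot_of_units)  # both are approximately 0.96
```

Normalizing once at insert time means a similarity search only needs dot products (or, equivalently, Euclidean distances between unit vectors) at query time.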

In [None]:
import cv2
import numpy as np
import torch
from PIL import Image
import clip

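def read_image_rgb(path):
    # cv2.imread returns None (instead of raising) when the file cannot be read.
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV loads images as BGR; CLIP preprocessing expects RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)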

# Other checkpoints from the list below (e.g. 'ViT-B/16') can be passed as model_name.
# Available models: ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
def clip_image_embedding(img, model_name='ViT-B/32'):
    model, preprocess = clip.load(model_name, device='cpu')
    img = Image.fromarray(img)
    img = preprocess(img).unsqueeze(0)
    with torch.no_grad():
        features = model.encode_image(img)
    return features.numpy()

def normalize_vector(vec):
    return vec / np.linalg.norm(vec)

# Execute the pipeline
input_path = '../data/teddy.png'
img = read_image_rgb(input_path)
vec = clip_image_embedding(img)
normalized_vec = normalize_vector(vec)

print(normalized_vec)


In [None]:
import clip
import torch
import numpy as np

def clip_text_embedding(text, model_name='ViT-B/32'):
    model, preprocess = clip.load(model_name, device='cpu')
    text_tokenized = clip.tokenize([text]).to('cpu')
    with torch.no_grad():
        features = model.encode_text(text_tokenized)
    return features.numpy()

def normalize_vector(vec):
    return vec / np.linalg.norm(vec)

# Execute the pipeline
input_text = "A teddybear on a skateboard in Times Square."
vec = clip_text_embedding(input_text)
normalized_vec = normalize_vector(vec)

print(normalized_vec)


### Load Image Embeddings into Celestica

We first extract embeddings from the images with the `ViT-B/32` model and then insert the embeddings into Celestica for indexing.

In [None]:
%%time

import csv
import cv2
import numpy as np
import torch
from PIL import Image
import clip
from celestica.client import CelesticaClient

def read_csv(csv_path, encoding='utf-8-sig'):
    with open(csv_path, 'r', encoding=encoding) as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line['id']), line['path']

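def read_image_rgb(path):
    # cv2.imread returns None (instead of raising) when the file cannot be read.
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)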

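# Cache loaded models so clip.load runs once per model name, not once per image.
_clip_models = {}

def clip_image_embedding(img, model_name='ViT-B/32'):
    if model_name not in _clip_models:
        _clip_models[model_name] = clip.load(model_name, device='cpu')
    model, preprocess = _clip_models[model_name]
    img = Image.fromarray(img)
    img = preprocess(img).unsqueeze(0)
    with torch.no_grad():
        features = model.encode_image(img)
    return features.numpy()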

def normalize_vector(vec):
    return vec / np.linalg.norm(vec)

def process_csv_file(csv_file):
    url = "celestica:50051"  # Replace with the address and port of your gRPC server
    client = CelesticaClient(url)
    
    # Use image_id rather than id to avoid shadowing the Python builtin.
    for image_id, path in read_csv(csv_file):
        img = read_image_rgb(path)
        vec = clip_image_embedding(img)
        normalized_vec = normalize_vector(vec)
        client.insert(normalized_vec.tolist(), [image_id])

csv_file = 'reverse_image_search.csv'
process_csv_file(csv_file)
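
The loop above issues one insert call per image. Because `insert` takes a list of vectors and a list of IDs, it may be possible to send several embeddings per call, but that is an assumption about the client and has not been verified here. The sketch below shows only the generic batching part; `chunked` is a hypothetical helper and the rows are made-up placeholders for the computed embeddings.

```python
from itertools import islice

def chunked(iterable, size):
    """Yield successive lists of at most `size` items from `iterable`."""
    if size < 1:
        raise ValueError("size must be at least 1")
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:
            return
        yield chunk

# Hypothetical (id, vector) pairs standing in for the embeddings computed above.
rows = [(i, [float(i), 0.0]) for i in range(5)]

batches = []
for chunk in chunked(rows, 2):
    ids = [image_id for image_id, _ in chunk]
    vectors = [vec for _, vec in chunk]
    # With a real client, one insert call per batch would go here,
    # ONLY IF the client accepts multi-vector batches (unverified assumption).
    batches.append((vectors, ids))

print([ids for _, ids in batches])
```

Each batch keeps vectors and IDs in the same order, so the pairing between an embedding and its image ID is preserved.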



### Query Matched Images from Celestica

Now that the embeddings for the candidate images have been inserted into Celestica, we can query it for nearest neighbors. Because Celestica returns only image IDs and distance values, we provide a `read_image` function that loads the original images by ID so they can be displayed.

In [None]:
import pandas as pd
import cv2
import numpy as np
from celestica.client import CelesticaClient
from IPython.display import display, Image
from PIL import Image as PILImage

def read_image(image_ids):
    df = pd.read_csv('reverse_image_search.csv')
    id_img = df.set_index('id')['path'].to_dict()
    imgs = []
    for image_id in image_ids:
        path = id_img[image_id]
        imgs.append(read_image_rgb(path))
    return imgs

def process_text_query(text):
    # Compute text embedding
    vec = clip_text_embedding(text)
    normalized_vec = normalize_vector(vec)

    # Search using Celestica client
    url = "celestica:50051"  # Replace with the address and port of your gRPC server
    client = CelesticaClient(url)
    knbn = 5
    ef = 10
    neighbours = client.search(normalized_vec.tolist(), knbn, ef)

    # Get image IDs from search results
    image_ids = [neighbour.point_id.index for neighbour in neighbours[0].neighbour]

    # Read images
    images = read_image(image_ids)
    return text, images

text1, images1 = process_text_query("A white dog")
text2, images2 = process_text_query("A black dog")

# Show the query text followed by its matched images.
for text, images in [(text1, images1), (text2, images2)]:
    print(f"Query: {text}")
    for img in images:
        display(PILImage.fromarray(img))
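
For a dataset this small, one way to sanity-check the IDs returned by a vector database is to compare them against an exhaustive search over the same normalized embeddings. The following is a minimal sketch of such a brute-force top-k ranking in plain Python; it does not use the Celestica client, and the 2-dimensional vectors are toy stand-ins for real embeddings.

```python
import heapq
import math

def l2_normalize(vec):
    """Scale vec to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def brute_force_top_k(query, candidates, k):
    """Return up to k (id, similarity) pairs, highest dot product first.

    `candidates` maps an ID to a unit-length vector, so the dot product
    is the cosine similarity.
    """
    scored = (
        (sum(q * c for q, c in zip(query, vec)), cid)
        for cid, vec in candidates.items()
    )
    return [(cid, score) for score, cid in heapq.nlargest(k, scored)]

# Toy candidate embeddings keyed by hypothetical image IDs.
candidates = {
    1: l2_normalize([1.0, 0.0]),
    2: l2_normalize([1.0, 1.0]),
    3: l2_normalize([0.0, 1.0]),
}
query = l2_normalize([0.9, 0.1])

top = brute_force_top_k(query, candidates, k=2)
print([cid for cid, _ in top])  # [1, 2]
```

If the database's search is approximate, its results may occasionally differ from this exhaustive ranking, and comparing the two gives a rough recall estimate.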

