# Prerequisites

In [None]:
%pip install torch transformers ipyplot datasets pymysql

# Initialize the Connection

In [None]:
import pymysql
def get_connection(
    host="greptimedb",
    port=4002,
    user="root",
    password="",
    database="public",
):
    """Open a new pymysql connection to the database server.

    Every argument defaults to the value this notebook was written against,
    so ``get_connection()`` connects exactly as before; pass overrides to
    point the notebook at a different deployment.

    Args:
        host: server hostname.
        port: MySQL-protocol port (4002 here; presumably the GreptimeDB
            MySQL endpoint. Confirm for your deployment).
        user: login user.
        password: login password (empty string, the same as not passing one).
        database: database selected after connecting.

    Returns:
        An open pymysql connection. The caller is responsible for closing it.
    """
    return pymysql.connect(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database,
    )


# Module-level connection and cursor shared by the cells below.
c = get_connection()
cursor = c.cursor()

# Create Table with Vector Column

In [None]:
# One row per image: `ts` is filled by its column default, `image_id` is the
# image's position in the dataset, and `embedding` holds its CLIP vector.
# VECTOR(512) has to match the length of the inserted embeddings
# (presumably 512 for openai/clip-vit-base-patch32; confirm if the model changes).
cursor.execute(
    "\n"
    "CREATE TABLE IF NOT EXISTS embedded_images_large(\n"
    "    ts TIMESTAMP TIME INDEX DEFAULT CURRENT_TIMESTAMP,\n"
    "    image_id INT PRIMARY KEY,\n"
    "    embedding VECTOR(512));\n"
)

# Prepare Model and Dataset

Note that loading the model may take several minutes.

In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel

import datasets

# The processor and the model must come from the same checkpoint, so name it once.
clip_checkpoint = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(clip_checkpoint)
model = CLIPModel.from_pretrained(clip_checkpoint)

# Training split of Tiny ImageNet; its records are read via their 'image' field below.
imagenet_datasets = datasets.load_dataset(
    'zh-plus/tiny-imagenet',
    split='train',
)

def encode_images_to_embeddings(images):
    """Compute CLIP image embeddings for a batch of images.

    Args:
        images: list of images in a form the CLIP processor accepts
            (e.g. PIL images from the dataset).

    Returns:
        NumPy array with one embedding row per input image.
    """
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pixel_batch = processor(images=images, return_tensors="pt")
        features = model.get_image_features(**pixel_batch)
        host_features = features.cpu().detach()
        return host_features.numpy()

def encode_text_to_embedding(text):
    """Compute the CLIP text embedding for a query string.

    Args:
        text: the query text.

    Returns:
        NumPy array: the first (for a single string, the only) row of the
        model's text features.
    """
    with torch.no_grad():
        # CLIP's text encoder has a fixed context window; truncate over-long
        # queries to the tokenizer's maximum length instead of letting the
        # forward pass fail on them.
        inputs = processor(text=text, return_tensors="pt", truncation=True)
        text_features = model.get_text_features(**inputs)
        return text_features.cpu().detach().numpy()[0]


def embedding_s(embedding):
    """Format a vector as the bracketed text literal used in this notebook's SQL.

    Each element is rendered with ``str`` and the parts are joined by commas
    with no spaces, e.g. ``[1, 2.5]`` -> ``"[1,2.5]"``.
    """
    parts = [str(value) for value in embedding]
    return "[" + ",".join(parts) + "]"

In [None]:
# Peek at the first training record; later cells read its 'image' field.
# Kept as the cell's final bare expression so Jupyter displays it.
imagenet_datasets[0]

# Store Embeddings of Images

Note that computing embeddings for all 100,000 images may take more than half an hour, so consider preparing the data in advance.

In [None]:
# Extract the image from every training record, preserving dataset order so a
# list position can serve as the image_id stored in the database.
imagenet_images = [record["image"] for record in imagenet_datasets]

In [None]:
def insert_images(images_embedding, begin_i):
    """Insert one row per embedding into ``embedded_images_large``.

    Args:
        images_embedding: sequence of embedding vectors (e.g. the rows
            returned by ``encode_images_to_embeddings``).
        begin_i: image_id for the first embedding. Later rows get
            consecutive ids, matching their positions in ``imagenet_images``.

    Uses the module-level ``cursor``. The ``ts`` column is filled by its
    DEFAULT.
    """
    for offset, vector in enumerate(images_embedding):
        cursor.execute(
            "INSERT INTO embedded_images_large VALUES "
            f"(DEFAULT, {offset + begin_i}, '{embedding_s(vector)}');"
        )

batch_size = 1000
# Ceiling division, so a final batch smaller than batch_size is still
# embedded and inserted instead of being silently skipped.
batch = (len(imagenet_images) + batch_size - 1) // batch_size

for b_i in range(batch):
    begin = b_i * batch_size
    # Slicing past the end is safe and yields the shorter final batch.
    image_batch = imagenet_images[begin:begin + batch_size]
    images_embedding = encode_images_to_embeddings(image_batch)
    # `begin` is the image_id of the batch's first image, which keeps ids
    # aligned with positions in imagenet_images.
    insert_images(images_embedding, begin)

# Search

In [None]:
import numpy as np
from PIL import Image
import ipyplot
import time

def search(query, k):
    """Return the ``k`` stored images closest to a text query.

    The query is embedded with CLIP and compared against every stored image
    embedding with ``cos_distance``. Rows are ordered by ascending distance,
    i.e. most similar first. The elapsed query time is printed.

    Args:
        query: free-text description to search for.
        k: maximum number of rows to return (interpolated into ``LIMIT``).

    Returns:
        The rows from ``cursor.fetchall()``: ``(image_id, distance)`` pairs.
    """
    query_embedding = embedding_s(encode_text_to_embedding(query))
    # perf_counter is monotonic and high-resolution, which makes it the right
    # clock for measuring an elapsed interval (time.time() can jump).
    start = time.perf_counter()
    cursor.execute(f"""
SELECT image_id, cos_distance(embedding, '{query_embedding}') AS distance
FROM embedded_images_large
ORDER BY distance
LIMIT {k};
    """)
    res = cursor.fetchall()
    print(f"Time taken: {time.perf_counter() - start}")
    return res

res = search("fire", 10)
similar_images = []
similarities = []

for image_id, d in res:
    image = imagenet_images[image_id]
    # Normalize any non-RGB image (grayscale "L", palette "P", "RGBA", ...)
    # to RGB so every array passed to ipyplot has the same 3-channel layout.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    similar_images.append(np.array(image))
    # Turn cosine distance back into cosine similarity for the label.
    similarities.append(round(1 - d, 3))

ipyplot.plot_images(similar_images, labels=similarities)