# Prerequisite

In [None]:
# Install the libraries used below: torch + transformers for the CLIP model,
# datasets for the images, ipyplot for displaying them, pymysql for the database.
%pip install torch transformers ipyplot datasets pymysql

# Init Connection

In [None]:
import pymysql
def get_connection(host="127.0.0.1", port=4002, user="root", database="public", **kwargs):
    """Open a MySQL-protocol connection to the database server used by this notebook.

    The defaults reproduce the original hard-coded settings (127.0.0.1:4002,
    user ``root``, database ``public``); override them to target another server.

    Args:
        host: server host name or IP address.
        port: MySQL-protocol port of the server.
        user: user name to authenticate as.
        database: database that unqualified table names resolve against.
        **kwargs: any further ``pymysql.connect`` option, e.g. ``password``.

    Returns:
        An open ``pymysql`` connection.
    """
    return pymysql.connect(
        host=host,
        port=port,
        user=user,
        database=database,
        **kwargs,
    )


# Shared connection and cursor used by the cells below.
c = get_connection()
cursor = c.cursor()

# Create Table with Vector Column

In [None]:
# Create the table holding one embedding per image:
#   ts        - time index column, defaulting to the insertion time
#   image_id  - index of the image in the dataset (used to look the image up again)
#   embedding - the image embedding as a 512-element vector
# NOTE(review): VECTOR(512) must match the embedding size of the CLIP model loaded
# below (512 for openai/clip-vit-base-patch32 — confirm if the model is changed).
# NOTE(review): since ts defaults to the insertion time, re-running the insert cell
# presumably adds duplicate rows for the same image_id — confirm, and recreate the
# table before re-inserting if so.
cursor.execute("""
CREATE TABLE IF NOT EXISTS embedded_images(
    ts TIMESTAMP TIME INDEX DEFAULT CURRENT_TIMESTAMP,
    image_id INT PRIMARY KEY,
    embedding VECTOR(512));
"""
)

# Prepare Model and Dataset

Note that loading the model may take a few minutes, since it is downloaded on first use.

In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel
import datasets

# The model and its processor must come from the same checkpoint, so the id is
# defined once. NOTE(review): the embedding size of this checkpoint has to match
# the VECTOR(...) width of the embedded_images table.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

# Only the train split of this dataset is loaded; its 'image' column is embedded below.
imagenet_datasets = datasets.load_dataset('theodor1289/imagenet-1k_tiny', split='train')

def encode_images_to_embeddings(images):
    """Return CLIP image embeddings for a list of images.

    All images go through the processor and the model in a single forward pass,
    with gradient tracking disabled.

    Args:
        images: a list of images accepted by the CLIP processor.

    Returns:
        A NumPy array of image features, one row per input image.
    """
    with torch.no_grad():
        pixel_inputs = processor(images=images, return_tensors="pt")
        features = model.get_image_features(**pixel_inputs)
    return features.cpu().detach().numpy()

def encode_text_to_embedding(text):
    """Return the CLIP text embedding for a single query string.

    Args:
        text: the query text to embed.

    Returns:
        A 1-D NumPy array: the features of the first (and only) text in the batch.
    """
    with torch.no_grad():
        # CLIP's text encoder has a fixed maximum token length; truncate long
        # queries to the tokenizer's limit instead of failing inside the model.
        inputs = processor(text=text, return_tensors="pt", truncation=True)
        text_features = model.get_text_features(**inputs)
    # Computed under no_grad, so no detach() is needed before converting.
    return text_features.cpu().numpy()[0]

# Inspect Images

Inspect a sample from the dataset.

In [None]:
# Display the first record to see which fields a dataset row provides
# (the 'image' field is the one used below).
imagenet_datasets[0]

Inspect a sample of the images.

In [None]:
import ipyplot

# Collect every image from the dataset; this list is reused for embedding and for
# showing search results, with list positions acting as image ids.
imagenet_images = [row['image'] for row in imagenet_datasets]

# Preview up to 20 of them as 100px-wide thumbnails.
ipyplot.plot_images(
    imagenet_images,
    max_images=20,
    img_width=100,
)

# Store Embeddings in the Vector Table

In [None]:
def embedding_s(embedding):
    """Render an embedding as a vector literal string, e.g. ``"[0.1,0.2]"``.

    Each element is converted with ``str`` and the elements are joined by commas
    (no spaces) inside square brackets.
    """
    body = ",".join(str(value) for value in embedding)
    return "[" + body + "]"

def insert_image(i, embedding=None):
    """Insert one image embedding into the ``embedded_images`` table.

    Args:
        i: index of the image in ``imagenet_images``; stored as ``image_id``.
        embedding: the image's embedding vector. When omitted, it is read from the
            global ``images_embedding`` array (the original behaviour), which must
            already exist.
    """
    if embedding is None:
        embedding = images_embedding[i]
    # Let the driver quote the values rather than splicing them into the SQL text.
    # DEFAULT fills the ts column with its CURRENT_TIMESTAMP default.
    cursor.execute(
        "INSERT INTO embedded_images VALUES (DEFAULT, %s, %s);",
        (i, embedding_s(embedding)),
    )


images_embedding = encode_images_to_embeddings(imagenet_images)

# Row position in images_embedding is the image's index in imagenet_images.
for i, embedding in enumerate(images_embedding):
    insert_image(i, embedding)

# Search

In [None]:
def search(query, k):
    """Return the ``k`` stored images closest to a text query.

    The query is embedded with CLIP and compared to every stored image embedding
    with ``vec_cos_distance``; results are ordered by ascending distance.

    Args:
        query: text describing the images to look for.
        k: maximum number of rows to return; must be convertible to ``int``.

    Returns:
        The fetched rows as ``(image_id, distance)`` pairs, nearest first.
    """
    query_embedding = embedding_s(encode_text_to_embedding(query))
    # Bind the vector literal and the limit as driver parameters; int() rejects a
    # non-numeric k instead of splicing arbitrary text into the SQL.
    cursor.execute(
        """
SELECT image_id, vec_cos_distance(embedding, %s) AS distance
FROM embedded_images
ORDER BY distance
LIMIT %s;
        """,
        (query_embedding, int(k)),
    )
    return cursor.fetchall()

# Look up the 5 images closest to "dog" and show them labelled with their cosine
# similarity (1 - cosine distance), rounded to 3 decimals.
res = search("dog", 5)
similar_images, similarities = [], []

for image_id, d in res:
    # image_id is the dataset index that was stored at insert time.
    similar_images.append(imagenet_images[image_id])
    similarities.append(round(1 - d, 3))

ipyplot.plot_images(similar_images, img_width=100, labels=similarities)